Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Functions.h +1427 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MemoryOverlap.h +42 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NativeMetaFunctions.h +1303 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NumericUtils.h +203 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/OpaqueTensorImpl.h +187 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Operators.h +1358 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h +17 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNativeTBB.h +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h +34 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h +13 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SparseTensorImpl.h +400 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/StorageUtils.h +49 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIndexing.h +735 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h +113 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TypeDefault.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Atomic.cuh +508 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh +537 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h +138 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraph.h +92 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh +57 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h +318 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh +15 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh +121 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h +10 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h +54 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h +151 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh +36 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh +124 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh +119 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +43 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh +116 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/jiterator.h +40 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h +174 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h +379 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h +34 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h +42 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cdist_forward_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_clamp_min_cpu_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_log_softmax_meta_dispatch.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_native.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_cpu_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_serialization_subcmul_ops.h +28 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/ArrayRef.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Functions.h
ADDED
|
@@ -0,0 +1,1427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Functions.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
|
| 17 |
+
see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
|
| 21 |
+
//
|
| 22 |
+
// In ATen, certain generated headers files include the definitions of
|
| 23 |
+
// every single operator in PyTorch. Unfortunately this means every
|
| 24 |
+
// time an operator signature is updated or changed in
|
| 25 |
+
// native_functions.yaml, you (and every other PyTorch developer) need
|
| 26 |
+
// to recompile every source file that includes any of these headers.
|
| 27 |
+
//
|
| 28 |
+
// To break up these header dependencies, and improve incremental
|
| 29 |
+
// build times for all PyTorch developers. These headers are split
|
| 30 |
+
// into per-operator headers in the `ATen/ops` folder. This limits
|
| 31 |
+
// incremental builds to only changes to methods of `Tensor`, or files
|
| 32 |
+
// that use the specific operator being changed. With `at::sum` as an
|
| 33 |
+
// example, you should include
|
| 34 |
+
//
|
| 35 |
+
// <ATen/ops/sum.h> // instead of ATen/Functions.h
|
| 36 |
+
// <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
|
| 37 |
+
// <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
|
| 38 |
+
// <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
|
| 39 |
+
//
|
| 40 |
+
// However, even if you're careful to use this in your own code.
|
| 41 |
+
// `Functions.h` might be included indirectly through another header
|
| 42 |
+
// without you realising. To avoid this, you can add
|
| 43 |
+
//
|
| 44 |
+
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 45 |
+
//
|
| 46 |
+
// to the top of your source file. This way any time the non-specific
|
| 47 |
+
// headers are included, the compiler will error out.
|
| 48 |
+
//
|
| 49 |
+
// Also, be aware that `ops` are not available in all build
|
| 50 |
+
// configurations (namely fb-internal) so you must guard these
|
| 51 |
+
// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
|
| 52 |
+
//
|
| 53 |
+
// #ifndef AT_PER_OPERATOR_HEADERS
|
| 54 |
+
// #include <ATen/Functions.h>
|
| 55 |
+
// #else
|
| 56 |
+
// #include <ATen/ops/sum.h>
|
| 57 |
+
// #endif
|
| 58 |
+
|
| 59 |
+
#include <ATen/Context.h>
|
| 60 |
+
#include <ATen/DeviceGuard.h>
|
| 61 |
+
#include <ATen/TensorUtils.h>
|
| 62 |
+
#include <ATen/TracerMode.h>
|
| 63 |
+
#include <ATen/core/Generator.h>
|
| 64 |
+
#include <ATen/core/Reduction.h>
|
| 65 |
+
#include <c10/core/SymInt.h>
|
| 66 |
+
#include <ATen/core/Tensor.h>
|
| 67 |
+
#include <c10/core/Scalar.h>
|
| 68 |
+
#include <c10/core/Storage.h>
|
| 69 |
+
#include <c10/core/TensorOptions.h>
|
| 70 |
+
#include <c10/util/Deprecated.h>
|
| 71 |
+
#include <c10/util/Optional.h>
|
| 72 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 73 |
+
|
| 74 |
+
#include <ATen/ops/from_blob.h>
|
| 75 |
+
#include <ATen/ops/tensor.h>
|
| 76 |
+
|
| 77 |
+
#include <ATen/ops/_adaptive_avg_pool2d.h>
|
| 78 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward.h>
|
| 79 |
+
#include <ATen/ops/_adaptive_avg_pool3d.h>
|
| 80 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward.h>
|
| 81 |
+
#include <ATen/ops/_add_batch_dim.h>
|
| 82 |
+
#include <ATen/ops/_add_relu.h>
|
| 83 |
+
#include <ATen/ops/_addmm_activation.h>
|
| 84 |
+
#include <ATen/ops/_aminmax.h>
|
| 85 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
|
| 86 |
+
#include <ATen/ops/_amp_update_scale.h>
|
| 87 |
+
#include <ATen/ops/_assert_async.h>
|
| 88 |
+
#include <ATen/ops/_assert_scalar.h>
|
| 89 |
+
#include <ATen/ops/_assert_tensor_metadata.h>
|
| 90 |
+
#include <ATen/ops/_autocast_to_full_precision.h>
|
| 91 |
+
#include <ATen/ops/_autocast_to_reduced_precision.h>
|
| 92 |
+
#include <ATen/ops/_backward.h>
|
| 93 |
+
#include <ATen/ops/_batch_norm_impl_index.h>
|
| 94 |
+
#include <ATen/ops/_batch_norm_impl_index_backward.h>
|
| 95 |
+
#include <ATen/ops/_cast_Byte.h>
|
| 96 |
+
#include <ATen/ops/_cast_Char.h>
|
| 97 |
+
#include <ATen/ops/_cast_Double.h>
|
| 98 |
+
#include <ATen/ops/_cast_Float.h>
|
| 99 |
+
#include <ATen/ops/_cast_Half.h>
|
| 100 |
+
#include <ATen/ops/_cast_Int.h>
|
| 101 |
+
#include <ATen/ops/_cast_Long.h>
|
| 102 |
+
#include <ATen/ops/_cast_Short.h>
|
| 103 |
+
#include <ATen/ops/_cdist_backward.h>
|
| 104 |
+
#include <ATen/ops/_cdist_forward.h>
|
| 105 |
+
#include <ATen/ops/_cholesky_solve_helper.h>
|
| 106 |
+
#include <ATen/ops/_choose_qparams_per_tensor.h>
|
| 107 |
+
#include <ATen/ops/_chunk_cat.h>
|
| 108 |
+
#include <ATen/ops/_coalesce.h>
|
| 109 |
+
#include <ATen/ops/_coalesced.h>
|
| 110 |
+
#include <ATen/ops/_compute_linear_combination.h>
|
| 111 |
+
#include <ATen/ops/_conj.h>
|
| 112 |
+
#include <ATen/ops/_conj_copy.h>
|
| 113 |
+
#include <ATen/ops/_conj_physical.h>
|
| 114 |
+
#include <ATen/ops/_conv_depthwise2d.h>
|
| 115 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
|
| 116 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
|
| 117 |
+
#include <ATen/ops/_convert_weight_to_int4pack.h>
|
| 118 |
+
#include <ATen/ops/_convolution.h>
|
| 119 |
+
#include <ATen/ops/_convolution_double_backward.h>
|
| 120 |
+
#include <ATen/ops/_convolution_mode.h>
|
| 121 |
+
#include <ATen/ops/_copy_from.h>
|
| 122 |
+
#include <ATen/ops/_copy_from_and_resize.h>
|
| 123 |
+
#include <ATen/ops/_cslt_compress.h>
|
| 124 |
+
#include <ATen/ops/_cslt_sparse_mm.h>
|
| 125 |
+
#include <ATen/ops/_cslt_sparse_mm_search.h>
|
| 126 |
+
#include <ATen/ops/_ctc_loss.h>
|
| 127 |
+
#include <ATen/ops/_ctc_loss_backward.h>
|
| 128 |
+
#include <ATen/ops/_cudnn_ctc_loss.h>
|
| 129 |
+
#include <ATen/ops/_cudnn_init_dropout_state.h>
|
| 130 |
+
#include <ATen/ops/_cudnn_rnn.h>
|
| 131 |
+
#include <ATen/ops/_cudnn_rnn_backward.h>
|
| 132 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight.h>
|
| 133 |
+
#include <ATen/ops/_cufft_clear_plan_cache.h>
|
| 134 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size.h>
|
| 135 |
+
#include <ATen/ops/_cufft_get_plan_cache_size.h>
|
| 136 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size.h>
|
| 137 |
+
#include <ATen/ops/_cummax_helper.h>
|
| 138 |
+
#include <ATen/ops/_cummin_helper.h>
|
| 139 |
+
#include <ATen/ops/_debug_has_internal_overlap.h>
|
| 140 |
+
#include <ATen/ops/_dimI.h>
|
| 141 |
+
#include <ATen/ops/_dimV.h>
|
| 142 |
+
#include <ATen/ops/_dim_arange.h>
|
| 143 |
+
#include <ATen/ops/_dirichlet_grad.h>
|
| 144 |
+
#include <ATen/ops/_efficient_attention_backward.h>
|
| 145 |
+
#include <ATen/ops/_efficient_attention_forward.h>
|
| 146 |
+
#include <ATen/ops/_efficientzerotensor.h>
|
| 147 |
+
#include <ATen/ops/_embedding_bag.h>
|
| 148 |
+
#include <ATen/ops/_embedding_bag_backward.h>
|
| 149 |
+
#include <ATen/ops/_embedding_bag_dense_backward.h>
|
| 150 |
+
#include <ATen/ops/_embedding_bag_forward_only.h>
|
| 151 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward.h>
|
| 152 |
+
#include <ATen/ops/_embedding_bag_sparse_backward.h>
|
| 153 |
+
#include <ATen/ops/_empty_affine_quantized.h>
|
| 154 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
|
| 155 |
+
#include <ATen/ops/_euclidean_dist.h>
|
| 156 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine.h>
|
| 157 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h>
|
| 158 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine.h>
|
| 159 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h>
|
| 160 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
|
| 161 |
+
#include <ATen/ops/_fft_c2c.h>
|
| 162 |
+
#include <ATen/ops/_fft_c2r.h>
|
| 163 |
+
#include <ATen/ops/_fft_r2c.h>
|
| 164 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask.h>
|
| 165 |
+
#include <ATen/ops/_flash_attention_backward.h>
|
| 166 |
+
#include <ATen/ops/_flash_attention_forward.h>
|
| 167 |
+
#include <ATen/ops/_foobar.h>
|
| 168 |
+
#include <ATen/ops/_foreach_abs.h>
|
| 169 |
+
#include <ATen/ops/_foreach_acos.h>
|
| 170 |
+
#include <ATen/ops/_foreach_add.h>
|
| 171 |
+
#include <ATen/ops/_foreach_addcdiv.h>
|
| 172 |
+
#include <ATen/ops/_foreach_addcmul.h>
|
| 173 |
+
#include <ATen/ops/_foreach_asin.h>
|
| 174 |
+
#include <ATen/ops/_foreach_atan.h>
|
| 175 |
+
#include <ATen/ops/_foreach_ceil.h>
|
| 176 |
+
#include <ATen/ops/_foreach_clamp_max.h>
|
| 177 |
+
#include <ATen/ops/_foreach_clamp_min.h>
|
| 178 |
+
#include <ATen/ops/_foreach_copy.h>
|
| 179 |
+
#include <ATen/ops/_foreach_cos.h>
|
| 180 |
+
#include <ATen/ops/_foreach_cosh.h>
|
| 181 |
+
#include <ATen/ops/_foreach_div.h>
|
| 182 |
+
#include <ATen/ops/_foreach_erf.h>
|
| 183 |
+
#include <ATen/ops/_foreach_erfc.h>
|
| 184 |
+
#include <ATen/ops/_foreach_exp.h>
|
| 185 |
+
#include <ATen/ops/_foreach_expm1.h>
|
| 186 |
+
#include <ATen/ops/_foreach_floor.h>
|
| 187 |
+
#include <ATen/ops/_foreach_frac.h>
|
| 188 |
+
#include <ATen/ops/_foreach_lerp.h>
|
| 189 |
+
#include <ATen/ops/_foreach_lgamma.h>
|
| 190 |
+
#include <ATen/ops/_foreach_log.h>
|
| 191 |
+
#include <ATen/ops/_foreach_log10.h>
|
| 192 |
+
#include <ATen/ops/_foreach_log1p.h>
|
| 193 |
+
#include <ATen/ops/_foreach_log2.h>
|
| 194 |
+
#include <ATen/ops/_foreach_maximum.h>
|
| 195 |
+
#include <ATen/ops/_foreach_minimum.h>
|
| 196 |
+
#include <ATen/ops/_foreach_mul.h>
|
| 197 |
+
#include <ATen/ops/_foreach_neg.h>
|
| 198 |
+
#include <ATen/ops/_foreach_norm.h>
|
| 199 |
+
#include <ATen/ops/_foreach_pow.h>
|
| 200 |
+
#include <ATen/ops/_foreach_reciprocal.h>
|
| 201 |
+
#include <ATen/ops/_foreach_round.h>
|
| 202 |
+
#include <ATen/ops/_foreach_sigmoid.h>
|
| 203 |
+
#include <ATen/ops/_foreach_sign.h>
|
| 204 |
+
#include <ATen/ops/_foreach_sin.h>
|
| 205 |
+
#include <ATen/ops/_foreach_sinh.h>
|
| 206 |
+
#include <ATen/ops/_foreach_sqrt.h>
|
| 207 |
+
#include <ATen/ops/_foreach_sub.h>
|
| 208 |
+
#include <ATen/ops/_foreach_tan.h>
|
| 209 |
+
#include <ATen/ops/_foreach_tanh.h>
|
| 210 |
+
#include <ATen/ops/_foreach_trunc.h>
|
| 211 |
+
#include <ATen/ops/_foreach_zero.h>
|
| 212 |
+
#include <ATen/ops/_functional_assert_async.h>
|
| 213 |
+
#include <ATen/ops/_functional_assert_scalar.h>
|
| 214 |
+
#include <ATen/ops/_functional_sym_constrain_range.h>
|
| 215 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size.h>
|
| 216 |
+
#include <ATen/ops/_fused_adam.h>
|
| 217 |
+
#include <ATen/ops/_fused_adamw.h>
|
| 218 |
+
#include <ATen/ops/_fused_dropout.h>
|
| 219 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
|
| 220 |
+
#include <ATen/ops/_fused_sdp_choice.h>
|
| 221 |
+
#include <ATen/ops/_fused_sgd.h>
|
| 222 |
+
#include <ATen/ops/_fw_primal.h>
|
| 223 |
+
#include <ATen/ops/_fw_primal_copy.h>
|
| 224 |
+
#include <ATen/ops/_gather_sparse_backward.h>
|
| 225 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback.h>
|
| 226 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h>
|
| 227 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type.h>
|
| 228 |
+
#include <ATen/ops/_has_same_storage_numel.h>
|
| 229 |
+
#include <ATen/ops/_histogramdd_bin_edges.h>
|
| 230 |
+
#include <ATen/ops/_histogramdd_from_bin_cts.h>
|
| 231 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors.h>
|
| 232 |
+
#include <ATen/ops/_index_put_impl.h>
|
| 233 |
+
#include <ATen/ops/_indices.h>
|
| 234 |
+
#include <ATen/ops/_indices_copy.h>
|
| 235 |
+
#include <ATen/ops/_int_mm.h>
|
| 236 |
+
#include <ATen/ops/_is_all_true.h>
|
| 237 |
+
#include <ATen/ops/_is_any_true.h>
|
| 238 |
+
#include <ATen/ops/_is_zerotensor.h>
|
| 239 |
+
#include <ATen/ops/_lazy_clone.h>
|
| 240 |
+
#include <ATen/ops/_linalg_check_errors.h>
|
| 241 |
+
#include <ATen/ops/_linalg_det.h>
|
| 242 |
+
#include <ATen/ops/_linalg_eigh.h>
|
| 243 |
+
#include <ATen/ops/_linalg_eigvals.h>
|
| 244 |
+
#include <ATen/ops/_linalg_slogdet.h>
|
| 245 |
+
#include <ATen/ops/_linalg_solve_ex.h>
|
| 246 |
+
#include <ATen/ops/_linalg_svd.h>
|
| 247 |
+
#include <ATen/ops/_local_scalar_dense.h>
|
| 248 |
+
#include <ATen/ops/_log_softmax.h>
|
| 249 |
+
#include <ATen/ops/_log_softmax_backward_data.h>
|
| 250 |
+
#include <ATen/ops/_logcumsumexp.h>
|
| 251 |
+
#include <ATen/ops/_lstm_mps.h>
|
| 252 |
+
#include <ATen/ops/_lu_with_info.h>
|
| 253 |
+
#include <ATen/ops/_make_dep_token.h>
|
| 254 |
+
#include <ATen/ops/_make_dual.h>
|
| 255 |
+
#include <ATen/ops/_make_dual_copy.h>
|
| 256 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor.h>
|
| 257 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
| 258 |
+
#include <ATen/ops/_masked_scale.h>
|
| 259 |
+
#include <ATen/ops/_masked_softmax.h>
|
| 260 |
+
#include <ATen/ops/_masked_softmax_backward.h>
|
| 261 |
+
#include <ATen/ops/_mixed_dtypes_linear.h>
|
| 262 |
+
#include <ATen/ops/_mkldnn_reshape.h>
|
| 263 |
+
#include <ATen/ops/_mkldnn_transpose.h>
|
| 264 |
+
#include <ATen/ops/_mps_convolution.h>
|
| 265 |
+
#include <ATen/ops/_mps_convolution_transpose.h>
|
| 266 |
+
#include <ATen/ops/_native_batch_norm_legit.h>
|
| 267 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
|
| 268 |
+
#include <ATen/ops/_native_multi_head_attention.h>
|
| 269 |
+
#include <ATen/ops/_neg_view.h>
|
| 270 |
+
#include <ATen/ops/_neg_view_copy.h>
|
| 271 |
+
#include <ATen/ops/_nested_from_padded.h>
|
| 272 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example.h>
|
| 273 |
+
#include <ATen/ops/_nested_get_jagged_dummy.h>
|
| 274 |
+
#include <ATen/ops/_nested_get_lengths.h>
|
| 275 |
+
#include <ATen/ops/_nested_get_offsets.h>
|
| 276 |
+
#include <ATen/ops/_nested_get_ragged_idx.h>
|
| 277 |
+
#include <ATen/ops/_nested_get_values.h>
|
| 278 |
+
#include <ATen/ops/_nested_get_values_copy.h>
|
| 279 |
+
#include <ATen/ops/_nested_select_backward.h>
|
| 280 |
+
#include <ATen/ops/_nested_sum_backward.h>
|
| 281 |
+
#include <ATen/ops/_nested_tensor_from_mask.h>
|
| 282 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned.h>
|
| 283 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list.h>
|
| 284 |
+
#include <ATen/ops/_nested_tensor_size.h>
|
| 285 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape.h>
|
| 286 |
+
#include <ATen/ops/_nested_tensor_storage_offsets.h>
|
| 287 |
+
#include <ATen/ops/_nested_tensor_strides.h>
|
| 288 |
+
#include <ATen/ops/_nested_view_from_buffer.h>
|
| 289 |
+
#include <ATen/ops/_nested_view_from_buffer_copy.h>
|
| 290 |
+
#include <ATen/ops/_nested_view_from_jagged.h>
|
| 291 |
+
#include <ATen/ops/_nested_view_from_jagged_copy.h>
|
| 292 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
|
| 293 |
+
#include <ATen/ops/_nnpack_available.h>
|
| 294 |
+
#include <ATen/ops/_nnpack_spatial_convolution.h>
|
| 295 |
+
#include <ATen/ops/_nnz.h>
|
| 296 |
+
#include <ATen/ops/_pack_padded_sequence.h>
|
| 297 |
+
#include <ATen/ops/_pack_padded_sequence_backward.h>
|
| 298 |
+
#include <ATen/ops/_pad_circular.h>
|
| 299 |
+
#include <ATen/ops/_pad_enum.h>
|
| 300 |
+
#include <ATen/ops/_pad_packed_sequence.h>
|
| 301 |
+
#include <ATen/ops/_pdist_backward.h>
|
| 302 |
+
#include <ATen/ops/_pdist_forward.h>
|
| 303 |
+
#include <ATen/ops/_pin_memory.h>
|
| 304 |
+
#include <ATen/ops/_prelu_kernel.h>
|
| 305 |
+
#include <ATen/ops/_prelu_kernel_backward.h>
|
| 306 |
+
#include <ATen/ops/_print.h>
|
| 307 |
+
#include <ATen/ops/_propagate_xla_data.h>
|
| 308 |
+
#include <ATen/ops/_remove_batch_dim.h>
|
| 309 |
+
#include <ATen/ops/_reshape_alias.h>
|
| 310 |
+
#include <ATen/ops/_reshape_alias_copy.h>
|
| 311 |
+
#include <ATen/ops/_reshape_copy.h>
|
| 312 |
+
#include <ATen/ops/_reshape_from_tensor.h>
|
| 313 |
+
#include <ATen/ops/_resize_output.h>
|
| 314 |
+
#include <ATen/ops/_rowwise_prune.h>
|
| 315 |
+
#include <ATen/ops/_sample_dirichlet.h>
|
| 316 |
+
#include <ATen/ops/_saturate_weight_to_fp16.h>
|
| 317 |
+
#include <ATen/ops/_scaled_dot_product_attention_math.h>
|
| 318 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
|
| 319 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
|
| 320 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward.h>
|
| 321 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
|
| 322 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>
|
| 323 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
|
| 324 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
|
| 325 |
+
#include <ATen/ops/_scaled_mm.h>
|
| 326 |
+
#include <ATen/ops/_segment_reduce_backward.h>
|
| 327 |
+
#include <ATen/ops/_shape_as_tensor.h>
|
| 328 |
+
#include <ATen/ops/_slow_conv2d_backward.h>
|
| 329 |
+
#include <ATen/ops/_slow_conv2d_forward.h>
|
| 330 |
+
#include <ATen/ops/_sobol_engine_draw.h>
|
| 331 |
+
#include <ATen/ops/_sobol_engine_ff.h>
|
| 332 |
+
#include <ATen/ops/_sobol_engine_initialize_state.h>
|
| 333 |
+
#include <ATen/ops/_sobol_engine_scramble.h>
|
| 334 |
+
#include <ATen/ops/_softmax.h>
|
| 335 |
+
#include <ATen/ops/_softmax_backward_data.h>
|
| 336 |
+
#include <ATen/ops/_sparse_addmm.h>
|
| 337 |
+
#include <ATen/ops/_sparse_broadcast_to.h>
|
| 338 |
+
#include <ATen/ops/_sparse_broadcast_to_copy.h>
|
| 339 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe.h>
|
| 340 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe.h>
|
| 341 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
|
| 342 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
| 343 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims.h>
|
| 344 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
|
| 345 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe.h>
|
| 346 |
+
#include <ATen/ops/_sparse_csr_prod.h>
|
| 347 |
+
#include <ATen/ops/_sparse_csr_sum.h>
|
| 348 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe.h>
|
| 349 |
+
#include <ATen/ops/_sparse_log_softmax.h>
|
| 350 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data.h>
|
| 351 |
+
#include <ATen/ops/_sparse_mask_projection.h>
|
| 352 |
+
#include <ATen/ops/_sparse_mm.h>
|
| 353 |
+
#include <ATen/ops/_sparse_mm_reduce_impl.h>
|
| 354 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward.h>
|
| 355 |
+
#include <ATen/ops/_sparse_semi_structured_linear.h>
|
| 356 |
+
#include <ATen/ops/_sparse_softmax.h>
|
| 357 |
+
#include <ATen/ops/_sparse_softmax_backward_data.h>
|
| 358 |
+
#include <ATen/ops/_sparse_sparse_matmul.h>
|
| 359 |
+
#include <ATen/ops/_sparse_sum.h>
|
| 360 |
+
#include <ATen/ops/_sparse_sum_backward.h>
|
| 361 |
+
#include <ATen/ops/_spdiags.h>
|
| 362 |
+
#include <ATen/ops/_stack.h>
|
| 363 |
+
#include <ATen/ops/_standard_gamma.h>
|
| 364 |
+
#include <ATen/ops/_standard_gamma_grad.h>
|
| 365 |
+
#include <ATen/ops/_test_ambiguous_defaults.h>
|
| 366 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch.h>
|
| 367 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view.h>
|
| 368 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy.h>
|
| 369 |
+
#include <ATen/ops/_test_check_tensor.h>
|
| 370 |
+
#include <ATen/ops/_test_functorch_fallback.h>
|
| 371 |
+
#include <ATen/ops/_test_optional_filled_intlist.h>
|
| 372 |
+
#include <ATen/ops/_test_optional_floatlist.h>
|
| 373 |
+
#include <ATen/ops/_test_optional_intlist.h>
|
| 374 |
+
#include <ATen/ops/_test_parallel_materialize.h>
|
| 375 |
+
#include <ATen/ops/_test_serialization_subcmul.h>
|
| 376 |
+
#include <ATen/ops/_test_string_default.h>
|
| 377 |
+
#include <ATen/ops/_test_warn_in_autograd.h>
|
| 378 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward.h>
|
| 379 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward.h>
|
| 380 |
+
#include <ATen/ops/_thnn_fused_gru_cell.h>
|
| 381 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward.h>
|
| 382 |
+
#include <ATen/ops/_thnn_fused_lstm_cell.h>
|
| 383 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
|
| 384 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
|
| 385 |
+
#include <ATen/ops/_to_copy.h>
|
| 386 |
+
#include <ATen/ops/_to_cpu.h>
|
| 387 |
+
#include <ATen/ops/_to_dense.h>
|
| 388 |
+
#include <ATen/ops/_to_sparse.h>
|
| 389 |
+
#include <ATen/ops/_to_sparse_bsc.h>
|
| 390 |
+
#include <ATen/ops/_to_sparse_bsr.h>
|
| 391 |
+
#include <ATen/ops/_to_sparse_csc.h>
|
| 392 |
+
#include <ATen/ops/_to_sparse_csr.h>
|
| 393 |
+
#include <ATen/ops/_to_sparse_semi_structured.h>
|
| 394 |
+
#include <ATen/ops/_transform_bias_rescale_qkv.h>
|
| 395 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd.h>
|
| 396 |
+
#include <ATen/ops/_trilinear.h>
|
| 397 |
+
#include <ATen/ops/_triton_multi_head_attention.h>
|
| 398 |
+
#include <ATen/ops/_triton_scaled_dot_attention.h>
|
| 399 |
+
#include <ATen/ops/_unique.h>
|
| 400 |
+
#include <ATen/ops/_unique2.h>
|
| 401 |
+
#include <ATen/ops/_unpack_dual.h>
|
| 402 |
+
#include <ATen/ops/_unsafe_index.h>
|
| 403 |
+
#include <ATen/ops/_unsafe_index_put.h>
|
| 404 |
+
#include <ATen/ops/_unsafe_view.h>
|
| 405 |
+
#include <ATen/ops/_upsample_bicubic2d_aa.h>
|
| 406 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward.h>
|
| 407 |
+
#include <ATen/ops/_upsample_bilinear2d_aa.h>
|
| 408 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
|
| 409 |
+
#include <ATen/ops/_upsample_nearest_exact1d.h>
|
| 410 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward.h>
|
| 411 |
+
#include <ATen/ops/_upsample_nearest_exact2d.h>
|
| 412 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward.h>
|
| 413 |
+
#include <ATen/ops/_upsample_nearest_exact3d.h>
|
| 414 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward.h>
|
| 415 |
+
#include <ATen/ops/_use_cudnn_ctc_loss.h>
|
| 416 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight.h>
|
| 417 |
+
#include <ATen/ops/_validate_compressed_sparse_indices.h>
|
| 418 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args.h>
|
| 419 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args.h>
|
| 420 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args.h>
|
| 421 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args.h>
|
| 422 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args.h>
|
| 423 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args.h>
|
| 424 |
+
#include <ATen/ops/_values.h>
|
| 425 |
+
#include <ATen/ops/_values_copy.h>
|
| 426 |
+
#include <ATen/ops/_version.h>
|
| 427 |
+
#include <ATen/ops/_weight_int4pack_mm.h>
|
| 428 |
+
#include <ATen/ops/_weight_int8pack_mm.h>
|
| 429 |
+
#include <ATen/ops/_weight_norm.h>
|
| 430 |
+
#include <ATen/ops/_weight_norm_differentiable_backward.h>
|
| 431 |
+
#include <ATen/ops/_weight_norm_interface.h>
|
| 432 |
+
#include <ATen/ops/_weight_norm_interface_backward.h>
|
| 433 |
+
#include <ATen/ops/abs.h>
|
| 434 |
+
#include <ATen/ops/absolute.h>
|
| 435 |
+
#include <ATen/ops/acos.h>
|
| 436 |
+
#include <ATen/ops/acosh.h>
|
| 437 |
+
#include <ATen/ops/adaptive_avg_pool1d.h>
|
| 438 |
+
#include <ATen/ops/adaptive_avg_pool2d.h>
|
| 439 |
+
#include <ATen/ops/adaptive_avg_pool3d.h>
|
| 440 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward.h>
|
| 441 |
+
#include <ATen/ops/adaptive_max_pool1d.h>
|
| 442 |
+
#include <ATen/ops/adaptive_max_pool2d.h>
|
| 443 |
+
#include <ATen/ops/adaptive_max_pool2d_backward.h>
|
| 444 |
+
#include <ATen/ops/adaptive_max_pool3d.h>
|
| 445 |
+
#include <ATen/ops/adaptive_max_pool3d_backward.h>
|
| 446 |
+
#include <ATen/ops/add.h>
|
| 447 |
+
#include <ATen/ops/addbmm.h>
|
| 448 |
+
#include <ATen/ops/addcdiv.h>
|
| 449 |
+
#include <ATen/ops/addcmul.h>
|
| 450 |
+
#include <ATen/ops/addmm.h>
|
| 451 |
+
#include <ATen/ops/addmv.h>
|
| 452 |
+
#include <ATen/ops/addr.h>
|
| 453 |
+
#include <ATen/ops/adjoint.h>
|
| 454 |
+
#include <ATen/ops/affine_grid_generator.h>
|
| 455 |
+
#include <ATen/ops/affine_grid_generator_backward.h>
|
| 456 |
+
#include <ATen/ops/alias.h>
|
| 457 |
+
#include <ATen/ops/alias_copy.h>
|
| 458 |
+
#include <ATen/ops/align_as.h>
|
| 459 |
+
#include <ATen/ops/align_tensors.h>
|
| 460 |
+
#include <ATen/ops/align_to.h>
|
| 461 |
+
#include <ATen/ops/all.h>
|
| 462 |
+
#include <ATen/ops/allclose.h>
|
| 463 |
+
#include <ATen/ops/alpha_dropout.h>
|
| 464 |
+
#include <ATen/ops/amax.h>
|
| 465 |
+
#include <ATen/ops/amin.h>
|
| 466 |
+
#include <ATen/ops/aminmax.h>
|
| 467 |
+
#include <ATen/ops/and.h>
|
| 468 |
+
#include <ATen/ops/angle.h>
|
| 469 |
+
#include <ATen/ops/any.h>
|
| 470 |
+
#include <ATen/ops/arange.h>
|
| 471 |
+
#include <ATen/ops/arccos.h>
|
| 472 |
+
#include <ATen/ops/arccosh.h>
|
| 473 |
+
#include <ATen/ops/arcsin.h>
|
| 474 |
+
#include <ATen/ops/arcsinh.h>
|
| 475 |
+
#include <ATen/ops/arctan.h>
|
| 476 |
+
#include <ATen/ops/arctan2.h>
|
| 477 |
+
#include <ATen/ops/arctanh.h>
|
| 478 |
+
#include <ATen/ops/argmax.h>
|
| 479 |
+
#include <ATen/ops/argmin.h>
|
| 480 |
+
#include <ATen/ops/argsort.h>
|
| 481 |
+
#include <ATen/ops/argwhere.h>
|
| 482 |
+
#include <ATen/ops/as_strided.h>
|
| 483 |
+
#include <ATen/ops/as_strided_copy.h>
|
| 484 |
+
#include <ATen/ops/as_strided_scatter.h>
|
| 485 |
+
#include <ATen/ops/asin.h>
|
| 486 |
+
#include <ATen/ops/asinh.h>
|
| 487 |
+
#include <ATen/ops/atan.h>
|
| 488 |
+
#include <ATen/ops/atan2.h>
|
| 489 |
+
#include <ATen/ops/atanh.h>
|
| 490 |
+
#include <ATen/ops/atleast_1d.h>
|
| 491 |
+
#include <ATen/ops/atleast_2d.h>
|
| 492 |
+
#include <ATen/ops/atleast_3d.h>
|
| 493 |
+
#include <ATen/ops/avg_pool1d.h>
|
| 494 |
+
#include <ATen/ops/avg_pool2d.h>
|
| 495 |
+
#include <ATen/ops/avg_pool2d_backward.h>
|
| 496 |
+
#include <ATen/ops/avg_pool3d.h>
|
| 497 |
+
#include <ATen/ops/avg_pool3d_backward.h>
|
| 498 |
+
#include <ATen/ops/baddbmm.h>
|
| 499 |
+
#include <ATen/ops/bartlett_window.h>
|
| 500 |
+
#include <ATen/ops/batch_norm.h>
|
| 501 |
+
#include <ATen/ops/batch_norm_backward_elemt.h>
|
| 502 |
+
#include <ATen/ops/batch_norm_backward_reduce.h>
|
| 503 |
+
#include <ATen/ops/batch_norm_elemt.h>
|
| 504 |
+
#include <ATen/ops/batch_norm_gather_stats.h>
|
| 505 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts.h>
|
| 506 |
+
#include <ATen/ops/batch_norm_stats.h>
|
| 507 |
+
#include <ATen/ops/batch_norm_update_stats.h>
|
| 508 |
+
#include <ATen/ops/bernoulli.h>
|
| 509 |
+
#include <ATen/ops/bilinear.h>
|
| 510 |
+
#include <ATen/ops/binary_cross_entropy.h>
|
| 511 |
+
#include <ATen/ops/binary_cross_entropy_backward.h>
|
| 512 |
+
#include <ATen/ops/binary_cross_entropy_with_logits.h>
|
| 513 |
+
#include <ATen/ops/bincount.h>
|
| 514 |
+
#include <ATen/ops/binomial.h>
|
| 515 |
+
#include <ATen/ops/bitwise_and.h>
|
| 516 |
+
#include <ATen/ops/bitwise_left_shift.h>
|
| 517 |
+
#include <ATen/ops/bitwise_not.h>
|
| 518 |
+
#include <ATen/ops/bitwise_or.h>
|
| 519 |
+
#include <ATen/ops/bitwise_right_shift.h>
|
| 520 |
+
#include <ATen/ops/bitwise_xor.h>
|
| 521 |
+
#include <ATen/ops/blackman_window.h>
|
| 522 |
+
#include <ATen/ops/block_diag.h>
|
| 523 |
+
#include <ATen/ops/bmm.h>
|
| 524 |
+
#include <ATen/ops/broadcast_tensors.h>
|
| 525 |
+
#include <ATen/ops/broadcast_to.h>
|
| 526 |
+
#include <ATen/ops/bucketize.h>
|
| 527 |
+
#include <ATen/ops/can_cast.h>
|
| 528 |
+
#include <ATen/ops/cartesian_prod.h>
|
| 529 |
+
#include <ATen/ops/cat.h>
|
| 530 |
+
#include <ATen/ops/cauchy.h>
|
| 531 |
+
#include <ATen/ops/ccol_indices.h>
|
| 532 |
+
#include <ATen/ops/ccol_indices_copy.h>
|
| 533 |
+
#include <ATen/ops/cdist.h>
|
| 534 |
+
#include <ATen/ops/ceil.h>
|
| 535 |
+
#include <ATen/ops/celu.h>
|
| 536 |
+
#include <ATen/ops/chain_matmul.h>
|
| 537 |
+
#include <ATen/ops/chalf.h>
|
| 538 |
+
#include <ATen/ops/channel_shuffle.h>
|
| 539 |
+
#include <ATen/ops/cholesky.h>
|
| 540 |
+
#include <ATen/ops/cholesky_inverse.h>
|
| 541 |
+
#include <ATen/ops/cholesky_solve.h>
|
| 542 |
+
#include <ATen/ops/choose_qparams_optimized.h>
|
| 543 |
+
#include <ATen/ops/chunk.h>
|
| 544 |
+
#include <ATen/ops/clamp.h>
|
| 545 |
+
#include <ATen/ops/clamp_max.h>
|
| 546 |
+
#include <ATen/ops/clamp_min.h>
|
| 547 |
+
#include <ATen/ops/clip.h>
|
| 548 |
+
#include <ATen/ops/clone.h>
|
| 549 |
+
#include <ATen/ops/coalesce.h>
|
| 550 |
+
#include <ATen/ops/col2im.h>
|
| 551 |
+
#include <ATen/ops/col_indices.h>
|
| 552 |
+
#include <ATen/ops/col_indices_copy.h>
|
| 553 |
+
#include <ATen/ops/column_stack.h>
|
| 554 |
+
#include <ATen/ops/combinations.h>
|
| 555 |
+
#include <ATen/ops/complex.h>
|
| 556 |
+
#include <ATen/ops/concat.h>
|
| 557 |
+
#include <ATen/ops/concatenate.h>
|
| 558 |
+
#include <ATen/ops/conj.h>
|
| 559 |
+
#include <ATen/ops/conj_physical.h>
|
| 560 |
+
#include <ATen/ops/constant_pad_nd.h>
|
| 561 |
+
#include <ATen/ops/contiguous.h>
|
| 562 |
+
#include <ATen/ops/conv1d.h>
|
| 563 |
+
#include <ATen/ops/conv2d.h>
|
| 564 |
+
#include <ATen/ops/conv3d.h>
|
| 565 |
+
#include <ATen/ops/conv_depthwise3d.h>
|
| 566 |
+
#include <ATen/ops/conv_tbc.h>
|
| 567 |
+
#include <ATen/ops/conv_tbc_backward.h>
|
| 568 |
+
#include <ATen/ops/conv_transpose1d.h>
|
| 569 |
+
#include <ATen/ops/conv_transpose2d.h>
|
| 570 |
+
#include <ATen/ops/conv_transpose3d.h>
|
| 571 |
+
#include <ATen/ops/convolution.h>
|
| 572 |
+
#include <ATen/ops/convolution_backward.h>
|
| 573 |
+
#include <ATen/ops/convolution_backward_overrideable.h>
|
| 574 |
+
#include <ATen/ops/convolution_overrideable.h>
|
| 575 |
+
#include <ATen/ops/copy.h>
|
| 576 |
+
#include <ATen/ops/copy_sparse_to_sparse.h>
|
| 577 |
+
#include <ATen/ops/copysign.h>
|
| 578 |
+
#include <ATen/ops/corrcoef.h>
|
| 579 |
+
#include <ATen/ops/cos.h>
|
| 580 |
+
#include <ATen/ops/cosh.h>
|
| 581 |
+
#include <ATen/ops/cosine_embedding_loss.h>
|
| 582 |
+
#include <ATen/ops/cosine_similarity.h>
|
| 583 |
+
#include <ATen/ops/count_nonzero.h>
|
| 584 |
+
#include <ATen/ops/cov.h>
|
| 585 |
+
#include <ATen/ops/cross.h>
|
| 586 |
+
#include <ATen/ops/cross_entropy_loss.h>
|
| 587 |
+
#include <ATen/ops/crow_indices.h>
|
| 588 |
+
#include <ATen/ops/crow_indices_copy.h>
|
| 589 |
+
#include <ATen/ops/ctc_loss.h>
|
| 590 |
+
#include <ATen/ops/cudnn_affine_grid_generator.h>
|
| 591 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward.h>
|
| 592 |
+
#include <ATen/ops/cudnn_batch_norm.h>
|
| 593 |
+
#include <ATen/ops/cudnn_batch_norm_backward.h>
|
| 594 |
+
#include <ATen/ops/cudnn_convolution.h>
|
| 595 |
+
#include <ATen/ops/cudnn_convolution_add_relu.h>
|
| 596 |
+
#include <ATen/ops/cudnn_convolution_relu.h>
|
| 597 |
+
#include <ATen/ops/cudnn_convolution_transpose.h>
|
| 598 |
+
#include <ATen/ops/cudnn_grid_sampler.h>
|
| 599 |
+
#include <ATen/ops/cudnn_grid_sampler_backward.h>
|
| 600 |
+
#include <ATen/ops/cudnn_is_acceptable.h>
|
| 601 |
+
#include <ATen/ops/cummax.h>
|
| 602 |
+
#include <ATen/ops/cummaxmin_backward.h>
|
| 603 |
+
#include <ATen/ops/cummin.h>
|
| 604 |
+
#include <ATen/ops/cumprod.h>
|
| 605 |
+
#include <ATen/ops/cumprod_backward.h>
|
| 606 |
+
#include <ATen/ops/cumsum.h>
|
| 607 |
+
#include <ATen/ops/cumulative_trapezoid.h>
|
| 608 |
+
#include <ATen/ops/data.h>
|
| 609 |
+
#include <ATen/ops/deg2rad.h>
|
| 610 |
+
#include <ATen/ops/dense_dim.h>
|
| 611 |
+
#include <ATen/ops/dequantize.h>
|
| 612 |
+
#include <ATen/ops/det.h>
|
| 613 |
+
#include <ATen/ops/detach.h>
|
| 614 |
+
#include <ATen/ops/detach_copy.h>
|
| 615 |
+
#include <ATen/ops/diag.h>
|
| 616 |
+
#include <ATen/ops/diag_embed.h>
|
| 617 |
+
#include <ATen/ops/diagflat.h>
|
| 618 |
+
#include <ATen/ops/diagonal.h>
|
| 619 |
+
#include <ATen/ops/diagonal_backward.h>
|
| 620 |
+
#include <ATen/ops/diagonal_copy.h>
|
| 621 |
+
#include <ATen/ops/diagonal_scatter.h>
|
| 622 |
+
#include <ATen/ops/diff.h>
|
| 623 |
+
#include <ATen/ops/digamma.h>
|
| 624 |
+
#include <ATen/ops/dist.h>
|
| 625 |
+
#include <ATen/ops/div.h>
|
| 626 |
+
#include <ATen/ops/divide.h>
|
| 627 |
+
#include <ATen/ops/dot.h>
|
| 628 |
+
#include <ATen/ops/dropout.h>
|
| 629 |
+
#include <ATen/ops/dsplit.h>
|
| 630 |
+
#include <ATen/ops/dstack.h>
|
| 631 |
+
#include <ATen/ops/einsum.h>
|
| 632 |
+
#include <ATen/ops/elu.h>
|
| 633 |
+
#include <ATen/ops/elu_backward.h>
|
| 634 |
+
#include <ATen/ops/embedding.h>
|
| 635 |
+
#include <ATen/ops/embedding_backward.h>
|
| 636 |
+
#include <ATen/ops/embedding_bag.h>
|
| 637 |
+
#include <ATen/ops/embedding_dense_backward.h>
|
| 638 |
+
#include <ATen/ops/embedding_renorm.h>
|
| 639 |
+
#include <ATen/ops/embedding_sparse_backward.h>
|
| 640 |
+
#include <ATen/ops/empty.h>
|
| 641 |
+
#include <ATen/ops/empty_like.h>
|
| 642 |
+
#include <ATen/ops/empty_permuted.h>
|
| 643 |
+
#include <ATen/ops/empty_quantized.h>
|
| 644 |
+
#include <ATen/ops/empty_strided.h>
|
| 645 |
+
#include <ATen/ops/eq.h>
|
| 646 |
+
#include <ATen/ops/equal.h>
|
| 647 |
+
#include <ATen/ops/erf.h>
|
| 648 |
+
#include <ATen/ops/erfc.h>
|
| 649 |
+
#include <ATen/ops/erfinv.h>
|
| 650 |
+
#include <ATen/ops/exp.h>
|
| 651 |
+
#include <ATen/ops/exp2.h>
|
| 652 |
+
#include <ATen/ops/expand.h>
|
| 653 |
+
#include <ATen/ops/expand_as.h>
|
| 654 |
+
#include <ATen/ops/expand_copy.h>
|
| 655 |
+
#include <ATen/ops/expm1.h>
|
| 656 |
+
#include <ATen/ops/exponential.h>
|
| 657 |
+
#include <ATen/ops/eye.h>
|
| 658 |
+
#include <ATen/ops/fake_quantize_per_channel_affine.h>
|
| 659 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
|
| 660 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h>
|
| 661 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine.h>
|
| 662 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask.h>
|
| 663 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h>
|
| 664 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight.h>
|
| 665 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
|
| 666 |
+
#include <ATen/ops/fbgemm_linear_int8_weight.h>
|
| 667 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
|
| 668 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight.h>
|
| 669 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
|
| 670 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix.h>
|
| 671 |
+
#include <ATen/ops/feature_alpha_dropout.h>
|
| 672 |
+
#include <ATen/ops/feature_dropout.h>
|
| 673 |
+
#include <ATen/ops/fft_fft.h>
|
| 674 |
+
#include <ATen/ops/fft_fft2.h>
|
| 675 |
+
#include <ATen/ops/fft_fftfreq.h>
|
| 676 |
+
#include <ATen/ops/fft_fftn.h>
|
| 677 |
+
#include <ATen/ops/fft_fftshift.h>
|
| 678 |
+
#include <ATen/ops/fft_hfft.h>
|
| 679 |
+
#include <ATen/ops/fft_hfft2.h>
|
| 680 |
+
#include <ATen/ops/fft_hfftn.h>
|
| 681 |
+
#include <ATen/ops/fft_ifft.h>
|
| 682 |
+
#include <ATen/ops/fft_ifft2.h>
|
| 683 |
+
#include <ATen/ops/fft_ifftn.h>
|
| 684 |
+
#include <ATen/ops/fft_ifftshift.h>
|
| 685 |
+
#include <ATen/ops/fft_ihfft.h>
|
| 686 |
+
#include <ATen/ops/fft_ihfft2.h>
|
| 687 |
+
#include <ATen/ops/fft_ihfftn.h>
|
| 688 |
+
#include <ATen/ops/fft_irfft.h>
|
| 689 |
+
#include <ATen/ops/fft_irfft2.h>
|
| 690 |
+
#include <ATen/ops/fft_irfftn.h>
|
| 691 |
+
#include <ATen/ops/fft_rfft.h>
|
| 692 |
+
#include <ATen/ops/fft_rfft2.h>
|
| 693 |
+
#include <ATen/ops/fft_rfftfreq.h>
|
| 694 |
+
#include <ATen/ops/fft_rfftn.h>
|
| 695 |
+
#include <ATen/ops/fill.h>
|
| 696 |
+
#include <ATen/ops/fill_diagonal.h>
|
| 697 |
+
#include <ATen/ops/fix.h>
|
| 698 |
+
#include <ATen/ops/flatten.h>
|
| 699 |
+
#include <ATen/ops/flatten_dense_tensors.h>
|
| 700 |
+
#include <ATen/ops/flip.h>
|
| 701 |
+
#include <ATen/ops/fliplr.h>
|
| 702 |
+
#include <ATen/ops/flipud.h>
|
| 703 |
+
#include <ATen/ops/float_power.h>
|
| 704 |
+
#include <ATen/ops/floor.h>
|
| 705 |
+
#include <ATen/ops/floor_divide.h>
|
| 706 |
+
#include <ATen/ops/fmax.h>
|
| 707 |
+
#include <ATen/ops/fmin.h>
|
| 708 |
+
#include <ATen/ops/fmod.h>
|
| 709 |
+
#include <ATen/ops/frac.h>
|
| 710 |
+
#include <ATen/ops/fractional_max_pool2d.h>
|
| 711 |
+
#include <ATen/ops/fractional_max_pool2d_backward.h>
|
| 712 |
+
#include <ATen/ops/fractional_max_pool3d.h>
|
| 713 |
+
#include <ATen/ops/fractional_max_pool3d_backward.h>
|
| 714 |
+
#include <ATen/ops/frexp.h>
|
| 715 |
+
#include <ATen/ops/frobenius_norm.h>
|
| 716 |
+
#include <ATen/ops/from_file.h>
|
| 717 |
+
#include <ATen/ops/full.h>
|
| 718 |
+
#include <ATen/ops/full_like.h>
|
| 719 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant.h>
|
| 720 |
+
#include <ATen/ops/gather.h>
|
| 721 |
+
#include <ATen/ops/gather_backward.h>
|
| 722 |
+
#include <ATen/ops/gcd.h>
|
| 723 |
+
#include <ATen/ops/ge.h>
|
| 724 |
+
#include <ATen/ops/gelu.h>
|
| 725 |
+
#include <ATen/ops/gelu_backward.h>
|
| 726 |
+
#include <ATen/ops/geometric.h>
|
| 727 |
+
#include <ATen/ops/geqrf.h>
|
| 728 |
+
#include <ATen/ops/ger.h>
|
| 729 |
+
#include <ATen/ops/glu.h>
|
| 730 |
+
#include <ATen/ops/glu_backward.h>
|
| 731 |
+
#include <ATen/ops/glu_backward_jvp.h>
|
| 732 |
+
#include <ATen/ops/glu_jvp.h>
|
| 733 |
+
#include <ATen/ops/gradient.h>
|
| 734 |
+
#include <ATen/ops/greater.h>
|
| 735 |
+
#include <ATen/ops/greater_equal.h>
|
| 736 |
+
#include <ATen/ops/grid_sampler.h>
|
| 737 |
+
#include <ATen/ops/grid_sampler_2d.h>
|
| 738 |
+
#include <ATen/ops/grid_sampler_2d_backward.h>
|
| 739 |
+
#include <ATen/ops/grid_sampler_3d.h>
|
| 740 |
+
#include <ATen/ops/grid_sampler_3d_backward.h>
|
| 741 |
+
#include <ATen/ops/group_norm.h>
|
| 742 |
+
#include <ATen/ops/gru.h>
|
| 743 |
+
#include <ATen/ops/gru_cell.h>
|
| 744 |
+
#include <ATen/ops/gt.h>
|
| 745 |
+
#include <ATen/ops/hamming_window.h>
|
| 746 |
+
#include <ATen/ops/hann_window.h>
|
| 747 |
+
#include <ATen/ops/hardshrink.h>
|
| 748 |
+
#include <ATen/ops/hardshrink_backward.h>
|
| 749 |
+
#include <ATen/ops/hardsigmoid.h>
|
| 750 |
+
#include <ATen/ops/hardsigmoid_backward.h>
|
| 751 |
+
#include <ATen/ops/hardswish.h>
|
| 752 |
+
#include <ATen/ops/hardswish_backward.h>
|
| 753 |
+
#include <ATen/ops/hardtanh.h>
|
| 754 |
+
#include <ATen/ops/hardtanh_backward.h>
|
| 755 |
+
#include <ATen/ops/heaviside.h>
|
| 756 |
+
#include <ATen/ops/hinge_embedding_loss.h>
|
| 757 |
+
#include <ATen/ops/histc.h>
|
| 758 |
+
#include <ATen/ops/histogram.h>
|
| 759 |
+
#include <ATen/ops/histogramdd.h>
|
| 760 |
+
#include <ATen/ops/hsplit.h>
|
| 761 |
+
#include <ATen/ops/hspmm.h>
|
| 762 |
+
#include <ATen/ops/hstack.h>
|
| 763 |
+
#include <ATen/ops/huber_loss.h>
|
| 764 |
+
#include <ATen/ops/huber_loss_backward.h>
|
| 765 |
+
#include <ATen/ops/hypot.h>
|
| 766 |
+
#include <ATen/ops/i0.h>
|
| 767 |
+
#include <ATen/ops/igamma.h>
|
| 768 |
+
#include <ATen/ops/igammac.h>
|
| 769 |
+
#include <ATen/ops/im2col.h>
|
| 770 |
+
#include <ATen/ops/imag.h>
|
| 771 |
+
#include <ATen/ops/index.h>
|
| 772 |
+
#include <ATen/ops/index_add.h>
|
| 773 |
+
#include <ATen/ops/index_copy.h>
|
| 774 |
+
#include <ATen/ops/index_fill.h>
|
| 775 |
+
#include <ATen/ops/index_put.h>
|
| 776 |
+
#include <ATen/ops/index_reduce.h>
|
| 777 |
+
#include <ATen/ops/index_select.h>
|
| 778 |
+
#include <ATen/ops/index_select_backward.h>
|
| 779 |
+
#include <ATen/ops/indices.h>
|
| 780 |
+
#include <ATen/ops/indices_copy.h>
|
| 781 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward.h>
|
| 782 |
+
#include <ATen/ops/inner.h>
|
| 783 |
+
#include <ATen/ops/instance_norm.h>
|
| 784 |
+
#include <ATen/ops/int_repr.h>
|
| 785 |
+
#include <ATen/ops/inverse.h>
|
| 786 |
+
#include <ATen/ops/is_coalesced.h>
|
| 787 |
+
#include <ATen/ops/is_complex.h>
|
| 788 |
+
#include <ATen/ops/is_conj.h>
|
| 789 |
+
#include <ATen/ops/is_distributed.h>
|
| 790 |
+
#include <ATen/ops/is_floating_point.h>
|
| 791 |
+
#include <ATen/ops/is_inference.h>
|
| 792 |
+
#include <ATen/ops/is_leaf.h>
|
| 793 |
+
#include <ATen/ops/is_neg.h>
|
| 794 |
+
#include <ATen/ops/is_nonzero.h>
|
| 795 |
+
#include <ATen/ops/is_pinned.h>
|
| 796 |
+
#include <ATen/ops/is_same_size.h>
|
| 797 |
+
#include <ATen/ops/is_set_to.h>
|
| 798 |
+
#include <ATen/ops/is_signed.h>
|
| 799 |
+
#include <ATen/ops/is_vulkan_available.h>
|
| 800 |
+
#include <ATen/ops/isclose.h>
|
| 801 |
+
#include <ATen/ops/isfinite.h>
|
| 802 |
+
#include <ATen/ops/isin.h>
|
| 803 |
+
#include <ATen/ops/isinf.h>
|
| 804 |
+
#include <ATen/ops/isnan.h>
|
| 805 |
+
#include <ATen/ops/isneginf.h>
|
| 806 |
+
#include <ATen/ops/isposinf.h>
|
| 807 |
+
#include <ATen/ops/isreal.h>
|
| 808 |
+
#include <ATen/ops/istft.h>
|
| 809 |
+
#include <ATen/ops/item.h>
|
| 810 |
+
#include <ATen/ops/kaiser_window.h>
|
| 811 |
+
#include <ATen/ops/kl_div.h>
|
| 812 |
+
#include <ATen/ops/kron.h>
|
| 813 |
+
#include <ATen/ops/kthvalue.h>
|
| 814 |
+
#include <ATen/ops/l1_loss.h>
|
| 815 |
+
#include <ATen/ops/layer_norm.h>
|
| 816 |
+
#include <ATen/ops/lcm.h>
|
| 817 |
+
#include <ATen/ops/ldexp.h>
|
| 818 |
+
#include <ATen/ops/le.h>
|
| 819 |
+
#include <ATen/ops/leaky_relu.h>
|
| 820 |
+
#include <ATen/ops/leaky_relu_backward.h>
|
| 821 |
+
#include <ATen/ops/lerp.h>
|
| 822 |
+
#include <ATen/ops/less.h>
|
| 823 |
+
#include <ATen/ops/less_equal.h>
|
| 824 |
+
#include <ATen/ops/lgamma.h>
|
| 825 |
+
#include <ATen/ops/lift.h>
|
| 826 |
+
#include <ATen/ops/lift_fresh.h>
|
| 827 |
+
#include <ATen/ops/lift_fresh_copy.h>
|
| 828 |
+
#include <ATen/ops/linalg_cholesky.h>
|
| 829 |
+
#include <ATen/ops/linalg_cholesky_ex.h>
|
| 830 |
+
#include <ATen/ops/linalg_cond.h>
|
| 831 |
+
#include <ATen/ops/linalg_cross.h>
|
| 832 |
+
#include <ATen/ops/linalg_det.h>
|
| 833 |
+
#include <ATen/ops/linalg_diagonal.h>
|
| 834 |
+
#include <ATen/ops/linalg_eig.h>
|
| 835 |
+
#include <ATen/ops/linalg_eigh.h>
|
| 836 |
+
#include <ATen/ops/linalg_eigvals.h>
|
| 837 |
+
#include <ATen/ops/linalg_eigvalsh.h>
|
| 838 |
+
#include <ATen/ops/linalg_householder_product.h>
|
| 839 |
+
#include <ATen/ops/linalg_inv.h>
|
| 840 |
+
#include <ATen/ops/linalg_inv_ex.h>
|
| 841 |
+
#include <ATen/ops/linalg_ldl_factor.h>
|
| 842 |
+
#include <ATen/ops/linalg_ldl_factor_ex.h>
|
| 843 |
+
#include <ATen/ops/linalg_ldl_solve.h>
|
| 844 |
+
#include <ATen/ops/linalg_lstsq.h>
|
| 845 |
+
#include <ATen/ops/linalg_lu.h>
|
| 846 |
+
#include <ATen/ops/linalg_lu_factor.h>
|
| 847 |
+
#include <ATen/ops/linalg_lu_factor_ex.h>
|
| 848 |
+
#include <ATen/ops/linalg_lu_solve.h>
|
| 849 |
+
#include <ATen/ops/linalg_matmul.h>
|
| 850 |
+
#include <ATen/ops/linalg_matrix_exp.h>
|
| 851 |
+
#include <ATen/ops/linalg_matrix_norm.h>
|
| 852 |
+
#include <ATen/ops/linalg_matrix_power.h>
|
| 853 |
+
#include <ATen/ops/linalg_matrix_rank.h>
|
| 854 |
+
#include <ATen/ops/linalg_multi_dot.h>
|
| 855 |
+
#include <ATen/ops/linalg_norm.h>
|
| 856 |
+
#include <ATen/ops/linalg_pinv.h>
|
| 857 |
+
#include <ATen/ops/linalg_qr.h>
|
| 858 |
+
#include <ATen/ops/linalg_slogdet.h>
|
| 859 |
+
#include <ATen/ops/linalg_solve.h>
|
| 860 |
+
#include <ATen/ops/linalg_solve_ex.h>
|
| 861 |
+
#include <ATen/ops/linalg_solve_triangular.h>
|
| 862 |
+
#include <ATen/ops/linalg_svd.h>
|
| 863 |
+
#include <ATen/ops/linalg_svdvals.h>
|
| 864 |
+
#include <ATen/ops/linalg_tensorinv.h>
|
| 865 |
+
#include <ATen/ops/linalg_tensorsolve.h>
|
| 866 |
+
#include <ATen/ops/linalg_vander.h>
|
| 867 |
+
#include <ATen/ops/linalg_vecdot.h>
|
| 868 |
+
#include <ATen/ops/linalg_vector_norm.h>
|
| 869 |
+
#include <ATen/ops/linear.h>
|
| 870 |
+
#include <ATen/ops/linear_backward.h>
|
| 871 |
+
#include <ATen/ops/linspace.h>
|
| 872 |
+
#include <ATen/ops/log.h>
|
| 873 |
+
#include <ATen/ops/log10.h>
|
| 874 |
+
#include <ATen/ops/log1p.h>
|
| 875 |
+
#include <ATen/ops/log2.h>
|
| 876 |
+
#include <ATen/ops/log_normal.h>
|
| 877 |
+
#include <ATen/ops/log_sigmoid.h>
|
| 878 |
+
#include <ATen/ops/log_sigmoid_backward.h>
|
| 879 |
+
#include <ATen/ops/log_sigmoid_forward.h>
|
| 880 |
+
#include <ATen/ops/log_softmax.h>
|
| 881 |
+
#include <ATen/ops/logaddexp.h>
|
| 882 |
+
#include <ATen/ops/logaddexp2.h>
|
| 883 |
+
#include <ATen/ops/logcumsumexp.h>
|
| 884 |
+
#include <ATen/ops/logdet.h>
|
| 885 |
+
#include <ATen/ops/logical_and.h>
|
| 886 |
+
#include <ATen/ops/logical_not.h>
|
| 887 |
+
#include <ATen/ops/logical_or.h>
|
| 888 |
+
#include <ATen/ops/logical_xor.h>
|
| 889 |
+
#include <ATen/ops/logit.h>
|
| 890 |
+
#include <ATen/ops/logit_backward.h>
|
| 891 |
+
#include <ATen/ops/logspace.h>
|
| 892 |
+
#include <ATen/ops/logsumexp.h>
|
| 893 |
+
#include <ATen/ops/lshift.h>
|
| 894 |
+
#include <ATen/ops/lstm.h>
|
| 895 |
+
#include <ATen/ops/lstm_cell.h>
|
| 896 |
+
#include <ATen/ops/lstm_mps_backward.h>
|
| 897 |
+
#include <ATen/ops/lt.h>
|
| 898 |
+
#include <ATen/ops/lu_solve.h>
|
| 899 |
+
#include <ATen/ops/lu_unpack.h>
|
| 900 |
+
#include <ATen/ops/mH.h>
|
| 901 |
+
#include <ATen/ops/mT.h>
|
| 902 |
+
#include <ATen/ops/margin_ranking_loss.h>
|
| 903 |
+
#include <ATen/ops/masked_fill.h>
|
| 904 |
+
#include <ATen/ops/masked_scatter.h>
|
| 905 |
+
#include <ATen/ops/masked_scatter_backward.h>
|
| 906 |
+
#include <ATen/ops/masked_select.h>
|
| 907 |
+
#include <ATen/ops/masked_select_backward.h>
|
| 908 |
+
#include <ATen/ops/matmul.h>
|
| 909 |
+
#include <ATen/ops/matmul_backward.h>
|
| 910 |
+
#include <ATen/ops/matrix_H.h>
|
| 911 |
+
#include <ATen/ops/matrix_exp.h>
|
| 912 |
+
#include <ATen/ops/matrix_exp_backward.h>
|
| 913 |
+
#include <ATen/ops/matrix_power.h>
|
| 914 |
+
#include <ATen/ops/max.h>
|
| 915 |
+
#include <ATen/ops/max_pool1d.h>
|
| 916 |
+
#include <ATen/ops/max_pool1d_with_indices.h>
|
| 917 |
+
#include <ATen/ops/max_pool2d.h>
|
| 918 |
+
#include <ATen/ops/max_pool2d_backward.h>
|
| 919 |
+
#include <ATen/ops/max_pool2d_with_indices.h>
|
| 920 |
+
#include <ATen/ops/max_pool2d_with_indices_backward.h>
|
| 921 |
+
#include <ATen/ops/max_pool3d.h>
|
| 922 |
+
#include <ATen/ops/max_pool3d_with_indices.h>
|
| 923 |
+
#include <ATen/ops/max_pool3d_with_indices_backward.h>
|
| 924 |
+
#include <ATen/ops/max_unpool2d.h>
|
| 925 |
+
#include <ATen/ops/max_unpool3d.h>
|
| 926 |
+
#include <ATen/ops/maximum.h>
|
| 927 |
+
#include <ATen/ops/mean.h>
|
| 928 |
+
#include <ATen/ops/median.h>
|
| 929 |
+
#include <ATen/ops/meshgrid.h>
|
| 930 |
+
#include <ATen/ops/min.h>
|
| 931 |
+
#include <ATen/ops/minimum.h>
|
| 932 |
+
#include <ATen/ops/miopen_batch_norm.h>
|
| 933 |
+
#include <ATen/ops/miopen_batch_norm_backward.h>
|
| 934 |
+
#include <ATen/ops/miopen_convolution.h>
|
| 935 |
+
#include <ATen/ops/miopen_convolution_add_relu.h>
|
| 936 |
+
#include <ATen/ops/miopen_convolution_relu.h>
|
| 937 |
+
#include <ATen/ops/miopen_convolution_transpose.h>
|
| 938 |
+
#include <ATen/ops/miopen_depthwise_convolution.h>
|
| 939 |
+
#include <ATen/ops/miopen_rnn.h>
|
| 940 |
+
#include <ATen/ops/miopen_rnn_backward.h>
|
| 941 |
+
#include <ATen/ops/mish.h>
|
| 942 |
+
#include <ATen/ops/mish_backward.h>
|
| 943 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d.h>
|
| 944 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h>
|
| 945 |
+
#include <ATen/ops/mkldnn_convolution.h>
|
| 946 |
+
#include <ATen/ops/mkldnn_linear.h>
|
| 947 |
+
#include <ATen/ops/mkldnn_linear_backward.h>
|
| 948 |
+
#include <ATen/ops/mkldnn_linear_backward_input.h>
|
| 949 |
+
#include <ATen/ops/mkldnn_linear_backward_weights.h>
|
| 950 |
+
#include <ATen/ops/mkldnn_max_pool2d.h>
|
| 951 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward.h>
|
| 952 |
+
#include <ATen/ops/mkldnn_max_pool3d.h>
|
| 953 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward.h>
|
| 954 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight.h>
|
| 955 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight.h>
|
| 956 |
+
#include <ATen/ops/mkldnn_rnn_layer.h>
|
| 957 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward.h>
|
| 958 |
+
#include <ATen/ops/mm.h>
|
| 959 |
+
#include <ATen/ops/mode.h>
|
| 960 |
+
#include <ATen/ops/moveaxis.h>
|
| 961 |
+
#include <ATen/ops/movedim.h>
|
| 962 |
+
#include <ATen/ops/mps_convolution_backward.h>
|
| 963 |
+
#include <ATen/ops/mps_convolution_transpose_backward.h>
|
| 964 |
+
#include <ATen/ops/mse_loss.h>
|
| 965 |
+
#include <ATen/ops/mse_loss_backward.h>
|
| 966 |
+
#include <ATen/ops/msort.h>
|
| 967 |
+
#include <ATen/ops/mul.h>
|
| 968 |
+
#include <ATen/ops/multi_margin_loss.h>
|
| 969 |
+
#include <ATen/ops/multi_margin_loss_backward.h>
|
| 970 |
+
#include <ATen/ops/multilabel_margin_loss.h>
|
| 971 |
+
#include <ATen/ops/multilabel_margin_loss_backward.h>
|
| 972 |
+
#include <ATen/ops/multilabel_margin_loss_forward.h>
|
| 973 |
+
#include <ATen/ops/multinomial.h>
|
| 974 |
+
#include <ATen/ops/multiply.h>
|
| 975 |
+
#include <ATen/ops/mv.h>
|
| 976 |
+
#include <ATen/ops/mvlgamma.h>
|
| 977 |
+
#include <ATen/ops/nan_to_num.h>
|
| 978 |
+
#include <ATen/ops/nanmean.h>
|
| 979 |
+
#include <ATen/ops/nanmedian.h>
|
| 980 |
+
#include <ATen/ops/nanquantile.h>
|
| 981 |
+
#include <ATen/ops/nansum.h>
|
| 982 |
+
#include <ATen/ops/narrow.h>
|
| 983 |
+
#include <ATen/ops/narrow_copy.h>
|
| 984 |
+
#include <ATen/ops/native_batch_norm.h>
|
| 985 |
+
#include <ATen/ops/native_batch_norm_backward.h>
|
| 986 |
+
#include <ATen/ops/native_channel_shuffle.h>
|
| 987 |
+
#include <ATen/ops/native_dropout.h>
|
| 988 |
+
#include <ATen/ops/native_dropout_backward.h>
|
| 989 |
+
#include <ATen/ops/native_group_norm.h>
|
| 990 |
+
#include <ATen/ops/native_group_norm_backward.h>
|
| 991 |
+
#include <ATen/ops/native_layer_norm.h>
|
| 992 |
+
#include <ATen/ops/native_layer_norm_backward.h>
|
| 993 |
+
#include <ATen/ops/native_norm.h>
|
| 994 |
+
#include <ATen/ops/ne.h>
|
| 995 |
+
#include <ATen/ops/neg.h>
|
| 996 |
+
#include <ATen/ops/negative.h>
|
| 997 |
+
#include <ATen/ops/nested_to_padded_tensor.h>
|
| 998 |
+
#include <ATen/ops/new_empty.h>
|
| 999 |
+
#include <ATen/ops/new_empty_strided.h>
|
| 1000 |
+
#include <ATen/ops/new_full.h>
|
| 1001 |
+
#include <ATen/ops/new_ones.h>
|
| 1002 |
+
#include <ATen/ops/new_zeros.h>
|
| 1003 |
+
#include <ATen/ops/nextafter.h>
|
| 1004 |
+
#include <ATen/ops/nll_loss.h>
|
| 1005 |
+
#include <ATen/ops/nll_loss2d.h>
|
| 1006 |
+
#include <ATen/ops/nll_loss2d_backward.h>
|
| 1007 |
+
#include <ATen/ops/nll_loss2d_forward.h>
|
| 1008 |
+
#include <ATen/ops/nll_loss_backward.h>
|
| 1009 |
+
#include <ATen/ops/nll_loss_forward.h>
|
| 1010 |
+
#include <ATen/ops/nll_loss_nd.h>
|
| 1011 |
+
#include <ATen/ops/nonzero.h>
|
| 1012 |
+
#include <ATen/ops/nonzero_numpy.h>
|
| 1013 |
+
#include <ATen/ops/nonzero_static.h>
|
| 1014 |
+
#include <ATen/ops/norm.h>
|
| 1015 |
+
#include <ATen/ops/norm_except_dim.h>
|
| 1016 |
+
#include <ATen/ops/normal.h>
|
| 1017 |
+
#include <ATen/ops/not_equal.h>
|
| 1018 |
+
#include <ATen/ops/nuclear_norm.h>
|
| 1019 |
+
#include <ATen/ops/numpy_T.h>
|
| 1020 |
+
#include <ATen/ops/one_hot.h>
|
| 1021 |
+
#include <ATen/ops/ones.h>
|
| 1022 |
+
#include <ATen/ops/ones_like.h>
|
| 1023 |
+
#include <ATen/ops/or.h>
|
| 1024 |
+
#include <ATen/ops/orgqr.h>
|
| 1025 |
+
#include <ATen/ops/ormqr.h>
|
| 1026 |
+
#include <ATen/ops/outer.h>
|
| 1027 |
+
#include <ATen/ops/output_nr.h>
|
| 1028 |
+
#include <ATen/ops/pad.h>
|
| 1029 |
+
#include <ATen/ops/pad_sequence.h>
|
| 1030 |
+
#include <ATen/ops/pairwise_distance.h>
|
| 1031 |
+
#include <ATen/ops/pdist.h>
|
| 1032 |
+
#include <ATen/ops/permute.h>
|
| 1033 |
+
#include <ATen/ops/permute_copy.h>
|
| 1034 |
+
#include <ATen/ops/pin_memory.h>
|
| 1035 |
+
#include <ATen/ops/pinverse.h>
|
| 1036 |
+
#include <ATen/ops/pixel_shuffle.h>
|
| 1037 |
+
#include <ATen/ops/pixel_unshuffle.h>
|
| 1038 |
+
#include <ATen/ops/poisson.h>
|
| 1039 |
+
#include <ATen/ops/poisson_nll_loss.h>
|
| 1040 |
+
#include <ATen/ops/polar.h>
|
| 1041 |
+
#include <ATen/ops/polygamma.h>
|
| 1042 |
+
#include <ATen/ops/positive.h>
|
| 1043 |
+
#include <ATen/ops/pow.h>
|
| 1044 |
+
#include <ATen/ops/prelu.h>
|
| 1045 |
+
#include <ATen/ops/prod.h>
|
| 1046 |
+
#include <ATen/ops/promote_types.h>
|
| 1047 |
+
#include <ATen/ops/put.h>
|
| 1048 |
+
#include <ATen/ops/q_per_channel_axis.h>
|
| 1049 |
+
#include <ATen/ops/q_per_channel_scales.h>
|
| 1050 |
+
#include <ATen/ops/q_per_channel_zero_points.h>
|
| 1051 |
+
#include <ATen/ops/q_scale.h>
|
| 1052 |
+
#include <ATen/ops/q_zero_point.h>
|
| 1053 |
+
#include <ATen/ops/qr.h>
|
| 1054 |
+
#include <ATen/ops/qscheme.h>
|
| 1055 |
+
#include <ATen/ops/quantile.h>
|
| 1056 |
+
#include <ATen/ops/quantize_per_channel.h>
|
| 1057 |
+
#include <ATen/ops/quantize_per_tensor.h>
|
| 1058 |
+
#include <ATen/ops/quantize_per_tensor_dynamic.h>
|
| 1059 |
+
#include <ATen/ops/quantized_batch_norm.h>
|
| 1060 |
+
#include <ATen/ops/quantized_gru_cell.h>
|
| 1061 |
+
#include <ATen/ops/quantized_lstm_cell.h>
|
| 1062 |
+
#include <ATen/ops/quantized_max_pool1d.h>
|
| 1063 |
+
#include <ATen/ops/quantized_max_pool2d.h>
|
| 1064 |
+
#include <ATen/ops/quantized_max_pool3d.h>
|
| 1065 |
+
#include <ATen/ops/quantized_rnn_relu_cell.h>
|
| 1066 |
+
#include <ATen/ops/quantized_rnn_tanh_cell.h>
|
| 1067 |
+
#include <ATen/ops/rad2deg.h>
|
| 1068 |
+
#include <ATen/ops/rand.h>
|
| 1069 |
+
#include <ATen/ops/rand_like.h>
|
| 1070 |
+
#include <ATen/ops/randint.h>
|
| 1071 |
+
#include <ATen/ops/randint_like.h>
|
| 1072 |
+
#include <ATen/ops/randn.h>
|
| 1073 |
+
#include <ATen/ops/randn_like.h>
|
| 1074 |
+
#include <ATen/ops/random.h>
|
| 1075 |
+
#include <ATen/ops/randperm.h>
|
| 1076 |
+
#include <ATen/ops/range.h>
|
| 1077 |
+
#include <ATen/ops/ravel.h>
|
| 1078 |
+
#include <ATen/ops/real.h>
|
| 1079 |
+
#include <ATen/ops/reciprocal.h>
|
| 1080 |
+
#include <ATen/ops/record_stream.h>
|
| 1081 |
+
#include <ATen/ops/refine_names.h>
|
| 1082 |
+
#include <ATen/ops/reflection_pad1d.h>
|
| 1083 |
+
#include <ATen/ops/reflection_pad1d_backward.h>
|
| 1084 |
+
#include <ATen/ops/reflection_pad2d.h>
|
| 1085 |
+
#include <ATen/ops/reflection_pad2d_backward.h>
|
| 1086 |
+
#include <ATen/ops/reflection_pad3d.h>
|
| 1087 |
+
#include <ATen/ops/reflection_pad3d_backward.h>
|
| 1088 |
+
#include <ATen/ops/relu.h>
|
| 1089 |
+
#include <ATen/ops/relu6.h>
|
| 1090 |
+
#include <ATen/ops/remainder.h>
|
| 1091 |
+
#include <ATen/ops/rename.h>
|
| 1092 |
+
#include <ATen/ops/renorm.h>
|
| 1093 |
+
#include <ATen/ops/repeat.h>
|
| 1094 |
+
#include <ATen/ops/repeat_interleave.h>
|
| 1095 |
+
#include <ATen/ops/replication_pad1d.h>
|
| 1096 |
+
#include <ATen/ops/replication_pad1d_backward.h>
|
| 1097 |
+
#include <ATen/ops/replication_pad2d.h>
|
| 1098 |
+
#include <ATen/ops/replication_pad2d_backward.h>
|
| 1099 |
+
#include <ATen/ops/replication_pad3d.h>
|
| 1100 |
+
#include <ATen/ops/replication_pad3d_backward.h>
|
| 1101 |
+
#include <ATen/ops/requires_grad.h>
|
| 1102 |
+
#include <ATen/ops/reshape.h>
|
| 1103 |
+
#include <ATen/ops/reshape_as.h>
|
| 1104 |
+
#include <ATen/ops/resize.h>
|
| 1105 |
+
#include <ATen/ops/resize_as.h>
|
| 1106 |
+
#include <ATen/ops/resize_as_sparse.h>
|
| 1107 |
+
#include <ATen/ops/resolve_conj.h>
|
| 1108 |
+
#include <ATen/ops/resolve_neg.h>
|
| 1109 |
+
#include <ATen/ops/result_type.h>
|
| 1110 |
+
#include <ATen/ops/retain_grad.h>
|
| 1111 |
+
#include <ATen/ops/retains_grad.h>
|
| 1112 |
+
#include <ATen/ops/rnn_relu.h>
|
| 1113 |
+
#include <ATen/ops/rnn_relu_cell.h>
|
| 1114 |
+
#include <ATen/ops/rnn_tanh.h>
|
| 1115 |
+
#include <ATen/ops/rnn_tanh_cell.h>
|
| 1116 |
+
#include <ATen/ops/roll.h>
|
| 1117 |
+
#include <ATen/ops/rot90.h>
|
| 1118 |
+
#include <ATen/ops/round.h>
|
| 1119 |
+
#include <ATen/ops/row_indices.h>
|
| 1120 |
+
#include <ATen/ops/row_indices_copy.h>
|
| 1121 |
+
#include <ATen/ops/row_stack.h>
|
| 1122 |
+
#include <ATen/ops/rrelu.h>
|
| 1123 |
+
#include <ATen/ops/rrelu_with_noise.h>
|
| 1124 |
+
#include <ATen/ops/rrelu_with_noise_backward.h>
|
| 1125 |
+
#include <ATen/ops/rshift.h>
|
| 1126 |
+
#include <ATen/ops/rsqrt.h>
|
| 1127 |
+
#include <ATen/ops/rsub.h>
|
| 1128 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 1129 |
+
#include <ATen/ops/scaled_dot_product_attention.h>
|
| 1130 |
+
#include <ATen/ops/scatter.h>
|
| 1131 |
+
#include <ATen/ops/scatter_add.h>
|
| 1132 |
+
#include <ATen/ops/scatter_reduce.h>
|
| 1133 |
+
#include <ATen/ops/searchsorted.h>
|
| 1134 |
+
#include <ATen/ops/segment_reduce.h>
|
| 1135 |
+
#include <ATen/ops/select.h>
|
| 1136 |
+
#include <ATen/ops/select_backward.h>
|
| 1137 |
+
#include <ATen/ops/select_copy.h>
|
| 1138 |
+
#include <ATen/ops/select_scatter.h>
|
| 1139 |
+
#include <ATen/ops/selu.h>
|
| 1140 |
+
#include <ATen/ops/set.h>
|
| 1141 |
+
#include <ATen/ops/set_data.h>
|
| 1142 |
+
#include <ATen/ops/sgn.h>
|
| 1143 |
+
#include <ATen/ops/sigmoid.h>
|
| 1144 |
+
#include <ATen/ops/sigmoid_backward.h>
|
| 1145 |
+
#include <ATen/ops/sign.h>
|
| 1146 |
+
#include <ATen/ops/signbit.h>
|
| 1147 |
+
#include <ATen/ops/silu.h>
|
| 1148 |
+
#include <ATen/ops/silu_backward.h>
|
| 1149 |
+
#include <ATen/ops/sin.h>
|
| 1150 |
+
#include <ATen/ops/sinc.h>
|
| 1151 |
+
#include <ATen/ops/sinh.h>
|
| 1152 |
+
#include <ATen/ops/size.h>
|
| 1153 |
+
#include <ATen/ops/slice.h>
|
| 1154 |
+
#include <ATen/ops/slice_backward.h>
|
| 1155 |
+
#include <ATen/ops/slice_copy.h>
|
| 1156 |
+
#include <ATen/ops/slice_inverse.h>
|
| 1157 |
+
#include <ATen/ops/slice_scatter.h>
|
| 1158 |
+
#include <ATen/ops/slogdet.h>
|
| 1159 |
+
#include <ATen/ops/slow_conv3d.h>
|
| 1160 |
+
#include <ATen/ops/slow_conv3d_forward.h>
|
| 1161 |
+
#include <ATen/ops/slow_conv_dilated2d.h>
|
| 1162 |
+
#include <ATen/ops/slow_conv_dilated3d.h>
|
| 1163 |
+
#include <ATen/ops/slow_conv_transpose2d.h>
|
| 1164 |
+
#include <ATen/ops/slow_conv_transpose3d.h>
|
| 1165 |
+
#include <ATen/ops/smm.h>
|
| 1166 |
+
#include <ATen/ops/smooth_l1_loss.h>
|
| 1167 |
+
#include <ATen/ops/smooth_l1_loss_backward.h>
|
| 1168 |
+
#include <ATen/ops/soft_margin_loss.h>
|
| 1169 |
+
#include <ATen/ops/soft_margin_loss_backward.h>
|
| 1170 |
+
#include <ATen/ops/softmax.h>
|
| 1171 |
+
#include <ATen/ops/softplus.h>
|
| 1172 |
+
#include <ATen/ops/softplus_backward.h>
|
| 1173 |
+
#include <ATen/ops/softshrink.h>
|
| 1174 |
+
#include <ATen/ops/softshrink_backward.h>
|
| 1175 |
+
#include <ATen/ops/sort.h>
|
| 1176 |
+
#include <ATen/ops/sparse_bsc_tensor.h>
|
| 1177 |
+
#include <ATen/ops/sparse_bsr_tensor.h>
|
| 1178 |
+
#include <ATen/ops/sparse_compressed_tensor.h>
|
| 1179 |
+
#include <ATen/ops/sparse_coo_tensor.h>
|
| 1180 |
+
#include <ATen/ops/sparse_csc_tensor.h>
|
| 1181 |
+
#include <ATen/ops/sparse_csr_tensor.h>
|
| 1182 |
+
#include <ATen/ops/sparse_dim.h>
|
| 1183 |
+
#include <ATen/ops/sparse_mask.h>
|
| 1184 |
+
#include <ATen/ops/sparse_resize.h>
|
| 1185 |
+
#include <ATen/ops/sparse_resize_and_clear.h>
|
| 1186 |
+
#include <ATen/ops/sparse_sampled_addmm.h>
|
| 1187 |
+
#include <ATen/ops/special_airy_ai.h>
|
| 1188 |
+
#include <ATen/ops/special_bessel_j0.h>
|
| 1189 |
+
#include <ATen/ops/special_bessel_j1.h>
|
| 1190 |
+
#include <ATen/ops/special_bessel_y0.h>
|
| 1191 |
+
#include <ATen/ops/special_bessel_y1.h>
|
| 1192 |
+
#include <ATen/ops/special_chebyshev_polynomial_t.h>
|
| 1193 |
+
#include <ATen/ops/special_chebyshev_polynomial_u.h>
|
| 1194 |
+
#include <ATen/ops/special_chebyshev_polynomial_v.h>
|
| 1195 |
+
#include <ATen/ops/special_chebyshev_polynomial_w.h>
|
| 1196 |
+
#include <ATen/ops/special_digamma.h>
|
| 1197 |
+
#include <ATen/ops/special_entr.h>
|
| 1198 |
+
#include <ATen/ops/special_erf.h>
|
| 1199 |
+
#include <ATen/ops/special_erfc.h>
|
| 1200 |
+
#include <ATen/ops/special_erfcx.h>
|
| 1201 |
+
#include <ATen/ops/special_erfinv.h>
|
| 1202 |
+
#include <ATen/ops/special_exp2.h>
|
| 1203 |
+
#include <ATen/ops/special_expit.h>
|
| 1204 |
+
#include <ATen/ops/special_expm1.h>
|
| 1205 |
+
#include <ATen/ops/special_gammainc.h>
|
| 1206 |
+
#include <ATen/ops/special_gammaincc.h>
|
| 1207 |
+
#include <ATen/ops/special_gammaln.h>
|
| 1208 |
+
#include <ATen/ops/special_hermite_polynomial_h.h>
|
| 1209 |
+
#include <ATen/ops/special_hermite_polynomial_he.h>
|
| 1210 |
+
#include <ATen/ops/special_i0.h>
|
| 1211 |
+
#include <ATen/ops/special_i0e.h>
|
| 1212 |
+
#include <ATen/ops/special_i1.h>
|
| 1213 |
+
#include <ATen/ops/special_i1e.h>
|
| 1214 |
+
#include <ATen/ops/special_laguerre_polynomial_l.h>
|
| 1215 |
+
#include <ATen/ops/special_legendre_polynomial_p.h>
|
| 1216 |
+
#include <ATen/ops/special_log1p.h>
|
| 1217 |
+
#include <ATen/ops/special_log_ndtr.h>
|
| 1218 |
+
#include <ATen/ops/special_log_softmax.h>
|
| 1219 |
+
#include <ATen/ops/special_logit.h>
|
| 1220 |
+
#include <ATen/ops/special_logsumexp.h>
|
| 1221 |
+
#include <ATen/ops/special_modified_bessel_i0.h>
|
| 1222 |
+
#include <ATen/ops/special_modified_bessel_i1.h>
|
| 1223 |
+
#include <ATen/ops/special_modified_bessel_k0.h>
|
| 1224 |
+
#include <ATen/ops/special_modified_bessel_k1.h>
|
| 1225 |
+
#include <ATen/ops/special_multigammaln.h>
|
| 1226 |
+
#include <ATen/ops/special_ndtr.h>
|
| 1227 |
+
#include <ATen/ops/special_ndtri.h>
|
| 1228 |
+
#include <ATen/ops/special_polygamma.h>
|
| 1229 |
+
#include <ATen/ops/special_psi.h>
|
| 1230 |
+
#include <ATen/ops/special_round.h>
|
| 1231 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0.h>
|
| 1232 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1.h>
|
| 1233 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
|
| 1234 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
|
| 1235 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
|
| 1236 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
|
| 1237 |
+
#include <ATen/ops/special_sinc.h>
|
| 1238 |
+
#include <ATen/ops/special_softmax.h>
|
| 1239 |
+
#include <ATen/ops/special_spherical_bessel_j0.h>
|
| 1240 |
+
#include <ATen/ops/special_xlog1py.h>
|
| 1241 |
+
#include <ATen/ops/special_xlogy.h>
|
| 1242 |
+
#include <ATen/ops/special_zeta.h>
|
| 1243 |
+
#include <ATen/ops/split.h>
|
| 1244 |
+
#include <ATen/ops/split_copy.h>
|
| 1245 |
+
#include <ATen/ops/split_with_sizes.h>
|
| 1246 |
+
#include <ATen/ops/split_with_sizes_copy.h>
|
| 1247 |
+
#include <ATen/ops/sqrt.h>
|
| 1248 |
+
#include <ATen/ops/square.h>
|
| 1249 |
+
#include <ATen/ops/squeeze.h>
|
| 1250 |
+
#include <ATen/ops/squeeze_copy.h>
|
| 1251 |
+
#include <ATen/ops/sspaddmm.h>
|
| 1252 |
+
#include <ATen/ops/stack.h>
|
| 1253 |
+
#include <ATen/ops/std.h>
|
| 1254 |
+
#include <ATen/ops/std_mean.h>
|
| 1255 |
+
#include <ATen/ops/stft.h>
|
| 1256 |
+
#include <ATen/ops/stride.h>
|
| 1257 |
+
#include <ATen/ops/sub.h>
|
| 1258 |
+
#include <ATen/ops/subtract.h>
|
| 1259 |
+
#include <ATen/ops/sum.h>
|
| 1260 |
+
#include <ATen/ops/sum_to_size.h>
|
| 1261 |
+
#include <ATen/ops/svd.h>
|
| 1262 |
+
#include <ATen/ops/swapaxes.h>
|
| 1263 |
+
#include <ATen/ops/swapdims.h>
|
| 1264 |
+
#include <ATen/ops/sym_constrain_range.h>
|
| 1265 |
+
#include <ATen/ops/sym_constrain_range_for_size.h>
|
| 1266 |
+
#include <ATen/ops/sym_numel.h>
|
| 1267 |
+
#include <ATen/ops/sym_size.h>
|
| 1268 |
+
#include <ATen/ops/sym_storage_offset.h>
|
| 1269 |
+
#include <ATen/ops/sym_stride.h>
|
| 1270 |
+
#include <ATen/ops/t.h>
|
| 1271 |
+
#include <ATen/ops/t_copy.h>
|
| 1272 |
+
#include <ATen/ops/take.h>
|
| 1273 |
+
#include <ATen/ops/take_along_dim.h>
|
| 1274 |
+
#include <ATen/ops/tan.h>
|
| 1275 |
+
#include <ATen/ops/tanh.h>
|
| 1276 |
+
#include <ATen/ops/tanh_backward.h>
|
| 1277 |
+
#include <ATen/ops/tensor_split.h>
|
| 1278 |
+
#include <ATen/ops/tensordot.h>
|
| 1279 |
+
#include <ATen/ops/thnn_conv2d.h>
|
| 1280 |
+
#include <ATen/ops/threshold.h>
|
| 1281 |
+
#include <ATen/ops/threshold_backward.h>
|
| 1282 |
+
#include <ATen/ops/tile.h>
|
| 1283 |
+
#include <ATen/ops/to.h>
|
| 1284 |
+
#include <ATen/ops/to_dense.h>
|
| 1285 |
+
#include <ATen/ops/to_dense_backward.h>
|
| 1286 |
+
#include <ATen/ops/to_mkldnn.h>
|
| 1287 |
+
#include <ATen/ops/to_mkldnn_backward.h>
|
| 1288 |
+
#include <ATen/ops/to_padded_tensor.h>
|
| 1289 |
+
#include <ATen/ops/to_sparse.h>
|
| 1290 |
+
#include <ATen/ops/to_sparse_bsc.h>
|
| 1291 |
+
#include <ATen/ops/to_sparse_bsr.h>
|
| 1292 |
+
#include <ATen/ops/to_sparse_csc.h>
|
| 1293 |
+
#include <ATen/ops/to_sparse_csr.h>
|
| 1294 |
+
#include <ATen/ops/topk.h>
|
| 1295 |
+
#include <ATen/ops/trace.h>
|
| 1296 |
+
#include <ATen/ops/trace_backward.h>
|
| 1297 |
+
#include <ATen/ops/transpose.h>
|
| 1298 |
+
#include <ATen/ops/transpose_copy.h>
|
| 1299 |
+
#include <ATen/ops/trapezoid.h>
|
| 1300 |
+
#include <ATen/ops/trapz.h>
|
| 1301 |
+
#include <ATen/ops/triangular_solve.h>
|
| 1302 |
+
#include <ATen/ops/tril.h>
|
| 1303 |
+
#include <ATen/ops/tril_indices.h>
|
| 1304 |
+
#include <ATen/ops/triplet_margin_loss.h>
|
| 1305 |
+
#include <ATen/ops/triu.h>
|
| 1306 |
+
#include <ATen/ops/triu_indices.h>
|
| 1307 |
+
#include <ATen/ops/true_divide.h>
|
| 1308 |
+
#include <ATen/ops/trunc.h>
|
| 1309 |
+
#include <ATen/ops/type_as.h>
|
| 1310 |
+
#include <ATen/ops/unbind.h>
|
| 1311 |
+
#include <ATen/ops/unbind_copy.h>
|
| 1312 |
+
#include <ATen/ops/unflatten.h>
|
| 1313 |
+
#include <ATen/ops/unflatten_dense_tensors.h>
|
| 1314 |
+
#include <ATen/ops/unfold.h>
|
| 1315 |
+
#include <ATen/ops/unfold_backward.h>
|
| 1316 |
+
#include <ATen/ops/unfold_copy.h>
|
| 1317 |
+
#include <ATen/ops/uniform.h>
|
| 1318 |
+
#include <ATen/ops/unique_consecutive.h>
|
| 1319 |
+
#include <ATen/ops/unique_dim.h>
|
| 1320 |
+
#include <ATen/ops/unique_dim_consecutive.h>
|
| 1321 |
+
#include <ATen/ops/unsafe_chunk.h>
|
| 1322 |
+
#include <ATen/ops/unsafe_split.h>
|
| 1323 |
+
#include <ATen/ops/unsafe_split_with_sizes.h>
|
| 1324 |
+
#include <ATen/ops/unsqueeze.h>
|
| 1325 |
+
#include <ATen/ops/unsqueeze_copy.h>
|
| 1326 |
+
#include <ATen/ops/upsample_bicubic2d.h>
|
| 1327 |
+
#include <ATen/ops/upsample_bicubic2d_backward.h>
|
| 1328 |
+
#include <ATen/ops/upsample_bilinear2d.h>
|
| 1329 |
+
#include <ATen/ops/upsample_bilinear2d_backward.h>
|
| 1330 |
+
#include <ATen/ops/upsample_linear1d.h>
|
| 1331 |
+
#include <ATen/ops/upsample_linear1d_backward.h>
|
| 1332 |
+
#include <ATen/ops/upsample_nearest1d.h>
|
| 1333 |
+
#include <ATen/ops/upsample_nearest1d_backward.h>
|
| 1334 |
+
#include <ATen/ops/upsample_nearest2d.h>
|
| 1335 |
+
#include <ATen/ops/upsample_nearest2d_backward.h>
|
| 1336 |
+
#include <ATen/ops/upsample_nearest3d.h>
|
| 1337 |
+
#include <ATen/ops/upsample_nearest3d_backward.h>
|
| 1338 |
+
#include <ATen/ops/upsample_trilinear3d.h>
|
| 1339 |
+
#include <ATen/ops/upsample_trilinear3d_backward.h>
|
| 1340 |
+
#include <ATen/ops/value_selecting_reduction_backward.h>
|
| 1341 |
+
#include <ATen/ops/values.h>
|
| 1342 |
+
#include <ATen/ops/values_copy.h>
|
| 1343 |
+
#include <ATen/ops/vander.h>
|
| 1344 |
+
#include <ATen/ops/var.h>
|
| 1345 |
+
#include <ATen/ops/var_mean.h>
|
| 1346 |
+
#include <ATen/ops/vdot.h>
|
| 1347 |
+
#include <ATen/ops/view.h>
|
| 1348 |
+
#include <ATen/ops/view_as.h>
|
| 1349 |
+
#include <ATen/ops/view_as_complex.h>
|
| 1350 |
+
#include <ATen/ops/view_as_complex_copy.h>
|
| 1351 |
+
#include <ATen/ops/view_as_real.h>
|
| 1352 |
+
#include <ATen/ops/view_as_real_copy.h>
|
| 1353 |
+
#include <ATen/ops/view_copy.h>
|
| 1354 |
+
#include <ATen/ops/vsplit.h>
|
| 1355 |
+
#include <ATen/ops/vstack.h>
|
| 1356 |
+
#include <ATen/ops/where.h>
|
| 1357 |
+
#include <ATen/ops/xlogy.h>
|
| 1358 |
+
#include <ATen/ops/xor.h>
|
| 1359 |
+
#include <ATen/ops/zero.h>
|
| 1360 |
+
#include <ATen/ops/zeros.h>
|
| 1361 |
+
#include <ATen/ops/zeros_like.h>
|
| 1362 |
+
|
| 1363 |
+
namespace at {
|
| 1364 |
+
|
| 1365 |
+
|
| 1366 |
+
|
| 1367 |
+
// Special C++ only overloads for std()-like functions (See gh-40287)
|
| 1368 |
+
// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
|
| 1369 |
+
// So, for example std(0) would select the std(unbiased=False) overload
|
| 1370 |
+
TORCH_API inline Tensor var(const Tensor& self, int dim) {
|
| 1371 |
+
return at::var(self, IntArrayRef{dim});
|
| 1372 |
+
}
|
| 1373 |
+
TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
|
| 1374 |
+
return at::var_mean(self, IntArrayRef{dim});
|
| 1375 |
+
}
|
| 1376 |
+
TORCH_API inline Tensor std(const Tensor& self, int dim) {
|
| 1377 |
+
return at::std(self, IntArrayRef{dim});
|
| 1378 |
+
}
|
| 1379 |
+
TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
|
| 1380 |
+
return at::std_mean(self, IntArrayRef{dim});
|
| 1381 |
+
}
|
| 1382 |
+
|
| 1383 |
+
inline int64_t numel(const Tensor& tensor) {
|
| 1384 |
+
return tensor.numel();
|
| 1385 |
+
}
|
| 1386 |
+
|
| 1387 |
+
inline int64_t size(const Tensor& tensor, int64_t dim) {
|
| 1388 |
+
return tensor.size(dim);
|
| 1389 |
+
}
|
| 1390 |
+
|
| 1391 |
+
inline int64_t stride(const Tensor& tensor, int64_t dim) {
|
| 1392 |
+
return tensor.stride(dim);
|
| 1393 |
+
}
|
| 1394 |
+
|
| 1395 |
+
inline bool is_complex(const Tensor& tensor) {
|
| 1396 |
+
return tensor.is_complex();
|
| 1397 |
+
}
|
| 1398 |
+
|
| 1399 |
+
inline bool is_floating_point(const Tensor& tensor) {
|
| 1400 |
+
return tensor.is_floating_point();
|
| 1401 |
+
}
|
| 1402 |
+
|
| 1403 |
+
inline bool is_signed(const Tensor& tensor) {
|
| 1404 |
+
return tensor.is_signed();
|
| 1405 |
+
}
|
| 1406 |
+
|
| 1407 |
+
inline bool is_inference(const Tensor& tensor) {
|
| 1408 |
+
return tensor.is_inference();
|
| 1409 |
+
}
|
| 1410 |
+
|
| 1411 |
+
inline bool _is_zerotensor(const Tensor& tensor) {
|
| 1412 |
+
return tensor._is_zerotensor();
|
| 1413 |
+
}
|
| 1414 |
+
|
| 1415 |
+
inline bool is_conj(const Tensor& tensor) {
|
| 1416 |
+
return tensor.is_conj();
|
| 1417 |
+
}
|
| 1418 |
+
|
| 1419 |
+
inline Tensor conj(const Tensor& tensor) {
|
| 1420 |
+
return tensor.conj();
|
| 1421 |
+
}
|
| 1422 |
+
|
| 1423 |
+
inline bool is_neg(const Tensor& tensor) {
|
| 1424 |
+
return tensor.is_neg();
|
| 1425 |
+
}
|
| 1426 |
+
|
| 1427 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MemoryOverlap.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
|
| 5 |
+
namespace c10 {
|
| 6 |
+
struct TensorImpl;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
class TensorBase;
|
| 11 |
+
|
| 12 |
+
// MemOverlap: Whether or not there is memory overlap
|
| 13 |
+
//
|
| 14 |
+
// No: Absolutely no memory overlap
|
| 15 |
+
// Yes: Absolutely yes memory overlap
|
| 16 |
+
// TooHard: There might be memory overlap, but it was too expensive to compute.
|
| 17 |
+
//
|
| 18 |
+
// NB: Please update the python test for these if you renumber them.
|
| 19 |
+
enum class MemOverlap { No, Yes, TooHard };
|
| 20 |
+
|
| 21 |
+
enum class MemOverlapStatus { Full, Partial, No, TooHard };
|
| 22 |
+
|
| 23 |
+
TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
|
| 24 |
+
TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
|
| 25 |
+
|
| 26 |
+
TORCH_API void assert_no_internal_overlap(const TensorBase& t);
|
| 27 |
+
TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
|
| 28 |
+
|
| 29 |
+
TORCH_API MemOverlapStatus
|
| 30 |
+
get_overlap_status(const TensorBase& a, const TensorBase& b);
|
| 31 |
+
TORCH_API MemOverlapStatus
|
| 32 |
+
get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
|
| 33 |
+
|
| 34 |
+
TORCH_API void assert_no_partial_overlap(
|
| 35 |
+
const TensorBase& a,
|
| 36 |
+
const TensorBase& b);
|
| 37 |
+
void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
|
| 38 |
+
|
| 39 |
+
TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
|
| 40 |
+
TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
|
| 41 |
+
|
| 42 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NativeMetaFunctions.h
ADDED
|
@@ -0,0 +1,1303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeMetaFunctions.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
#include <ATen/core/IListRef.h>
|
| 7 |
+
#include <ATen/TensorMeta.h>
|
| 8 |
+
#include <ATen/TensorIterator.h>
|
| 9 |
+
|
| 10 |
+
#include <ATen/ops/_adaptive_avg_pool2d_meta.h>
|
| 11 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_meta.h>
|
| 12 |
+
#include <ATen/ops/_adaptive_avg_pool3d_meta.h>
|
| 13 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_meta.h>
|
| 14 |
+
#include <ATen/ops/_add_batch_dim_meta.h>
|
| 15 |
+
#include <ATen/ops/_add_relu_meta.h>
|
| 16 |
+
#include <ATen/ops/_addmm_activation_meta.h>
|
| 17 |
+
#include <ATen/ops/_aminmax_meta.h>
|
| 18 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_meta.h>
|
| 19 |
+
#include <ATen/ops/_amp_update_scale_meta.h>
|
| 20 |
+
#include <ATen/ops/_assert_async_meta.h>
|
| 21 |
+
#include <ATen/ops/_assert_scalar_meta.h>
|
| 22 |
+
#include <ATen/ops/_assert_tensor_metadata_meta.h>
|
| 23 |
+
#include <ATen/ops/_autocast_to_full_precision_meta.h>
|
| 24 |
+
#include <ATen/ops/_autocast_to_reduced_precision_meta.h>
|
| 25 |
+
#include <ATen/ops/_backward_meta.h>
|
| 26 |
+
#include <ATen/ops/_batch_norm_impl_index_meta.h>
|
| 27 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_meta.h>
|
| 28 |
+
#include <ATen/ops/_cast_Byte_meta.h>
|
| 29 |
+
#include <ATen/ops/_cast_Char_meta.h>
|
| 30 |
+
#include <ATen/ops/_cast_Double_meta.h>
|
| 31 |
+
#include <ATen/ops/_cast_Float_meta.h>
|
| 32 |
+
#include <ATen/ops/_cast_Half_meta.h>
|
| 33 |
+
#include <ATen/ops/_cast_Int_meta.h>
|
| 34 |
+
#include <ATen/ops/_cast_Long_meta.h>
|
| 35 |
+
#include <ATen/ops/_cast_Short_meta.h>
|
| 36 |
+
#include <ATen/ops/_cdist_backward_meta.h>
|
| 37 |
+
#include <ATen/ops/_cdist_forward_meta.h>
|
| 38 |
+
#include <ATen/ops/_cholesky_solve_helper_meta.h>
|
| 39 |
+
#include <ATen/ops/_choose_qparams_per_tensor_meta.h>
|
| 40 |
+
#include <ATen/ops/_chunk_cat_meta.h>
|
| 41 |
+
#include <ATen/ops/_coalesce_meta.h>
|
| 42 |
+
#include <ATen/ops/_coalesced_meta.h>
|
| 43 |
+
#include <ATen/ops/_compute_linear_combination_meta.h>
|
| 44 |
+
#include <ATen/ops/_conj_meta.h>
|
| 45 |
+
#include <ATen/ops/_conj_copy_meta.h>
|
| 46 |
+
#include <ATen/ops/_conj_physical_meta.h>
|
| 47 |
+
#include <ATen/ops/_conv_depthwise2d_meta.h>
|
| 48 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta.h>
|
| 49 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta.h>
|
| 50 |
+
#include <ATen/ops/_convert_weight_to_int4pack_meta.h>
|
| 51 |
+
#include <ATen/ops/_convolution_meta.h>
|
| 52 |
+
#include <ATen/ops/_convolution_double_backward_meta.h>
|
| 53 |
+
#include <ATen/ops/_convolution_mode_meta.h>
|
| 54 |
+
#include <ATen/ops/_copy_from_meta.h>
|
| 55 |
+
#include <ATen/ops/_copy_from_and_resize_meta.h>
|
| 56 |
+
#include <ATen/ops/_cslt_compress_meta.h>
|
| 57 |
+
#include <ATen/ops/_cslt_sparse_mm_meta.h>
|
| 58 |
+
#include <ATen/ops/_cslt_sparse_mm_search_meta.h>
|
| 59 |
+
#include <ATen/ops/_ctc_loss_meta.h>
|
| 60 |
+
#include <ATen/ops/_ctc_loss_backward_meta.h>
|
| 61 |
+
#include <ATen/ops/_cudnn_ctc_loss_meta.h>
|
| 62 |
+
#include <ATen/ops/_cudnn_init_dropout_state_meta.h>
|
| 63 |
+
#include <ATen/ops/_cudnn_rnn_meta.h>
|
| 64 |
+
#include <ATen/ops/_cudnn_rnn_backward_meta.h>
|
| 65 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_meta.h>
|
| 66 |
+
#include <ATen/ops/_cufft_clear_plan_cache_meta.h>
|
| 67 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_meta.h>
|
| 68 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_meta.h>
|
| 69 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_meta.h>
|
| 70 |
+
#include <ATen/ops/_cummax_helper_meta.h>
|
| 71 |
+
#include <ATen/ops/_cummin_helper_meta.h>
|
| 72 |
+
#include <ATen/ops/_debug_has_internal_overlap_meta.h>
|
| 73 |
+
#include <ATen/ops/_dimI_meta.h>
|
| 74 |
+
#include <ATen/ops/_dimV_meta.h>
|
| 75 |
+
#include <ATen/ops/_dim_arange_meta.h>
|
| 76 |
+
#include <ATen/ops/_dirichlet_grad_meta.h>
|
| 77 |
+
#include <ATen/ops/_efficient_attention_backward_meta.h>
|
| 78 |
+
#include <ATen/ops/_efficient_attention_forward_meta.h>
|
| 79 |
+
#include <ATen/ops/_efficientzerotensor_meta.h>
|
| 80 |
+
#include <ATen/ops/_embedding_bag_meta.h>
|
| 81 |
+
#include <ATen/ops/_embedding_bag_backward_meta.h>
|
| 82 |
+
#include <ATen/ops/_embedding_bag_dense_backward_meta.h>
|
| 83 |
+
#include <ATen/ops/_embedding_bag_forward_only_meta.h>
|
| 84 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_meta.h>
|
| 85 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_meta.h>
|
| 86 |
+
#include <ATen/ops/_empty_affine_quantized_meta.h>
|
| 87 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_meta.h>
|
| 88 |
+
#include <ATen/ops/_euclidean_dist_meta.h>
|
| 89 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_meta.h>
|
| 90 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_meta.h>
|
| 91 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_meta.h>
|
| 92 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_meta.h>
|
| 93 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_meta.h>
|
| 94 |
+
#include <ATen/ops/_fft_c2c_meta.h>
|
| 95 |
+
#include <ATen/ops/_fft_c2r_meta.h>
|
| 96 |
+
#include <ATen/ops/_fft_r2c_meta.h>
|
| 97 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta.h>
|
| 98 |
+
#include <ATen/ops/_flash_attention_backward_meta.h>
|
| 99 |
+
#include <ATen/ops/_flash_attention_forward_meta.h>
|
| 100 |
+
#include <ATen/ops/_foobar_meta.h>
|
| 101 |
+
#include <ATen/ops/_foreach_abs_meta.h>
|
| 102 |
+
#include <ATen/ops/_foreach_acos_meta.h>
|
| 103 |
+
#include <ATen/ops/_foreach_add_meta.h>
|
| 104 |
+
#include <ATen/ops/_foreach_addcdiv_meta.h>
|
| 105 |
+
#include <ATen/ops/_foreach_addcmul_meta.h>
|
| 106 |
+
#include <ATen/ops/_foreach_asin_meta.h>
|
| 107 |
+
#include <ATen/ops/_foreach_atan_meta.h>
|
| 108 |
+
#include <ATen/ops/_foreach_ceil_meta.h>
|
| 109 |
+
#include <ATen/ops/_foreach_clamp_max_meta.h>
|
| 110 |
+
#include <ATen/ops/_foreach_clamp_min_meta.h>
|
| 111 |
+
#include <ATen/ops/_foreach_copy_meta.h>
|
| 112 |
+
#include <ATen/ops/_foreach_cos_meta.h>
|
| 113 |
+
#include <ATen/ops/_foreach_cosh_meta.h>
|
| 114 |
+
#include <ATen/ops/_foreach_div_meta.h>
|
| 115 |
+
#include <ATen/ops/_foreach_erf_meta.h>
|
| 116 |
+
#include <ATen/ops/_foreach_erfc_meta.h>
|
| 117 |
+
#include <ATen/ops/_foreach_exp_meta.h>
|
| 118 |
+
#include <ATen/ops/_foreach_expm1_meta.h>
|
| 119 |
+
#include <ATen/ops/_foreach_floor_meta.h>
|
| 120 |
+
#include <ATen/ops/_foreach_frac_meta.h>
|
| 121 |
+
#include <ATen/ops/_foreach_lerp_meta.h>
|
| 122 |
+
#include <ATen/ops/_foreach_lgamma_meta.h>
|
| 123 |
+
#include <ATen/ops/_foreach_log_meta.h>
|
| 124 |
+
#include <ATen/ops/_foreach_log10_meta.h>
|
| 125 |
+
#include <ATen/ops/_foreach_log1p_meta.h>
|
| 126 |
+
#include <ATen/ops/_foreach_log2_meta.h>
|
| 127 |
+
#include <ATen/ops/_foreach_maximum_meta.h>
|
| 128 |
+
#include <ATen/ops/_foreach_minimum_meta.h>
|
| 129 |
+
#include <ATen/ops/_foreach_mul_meta.h>
|
| 130 |
+
#include <ATen/ops/_foreach_neg_meta.h>
|
| 131 |
+
#include <ATen/ops/_foreach_norm_meta.h>
|
| 132 |
+
#include <ATen/ops/_foreach_pow_meta.h>
|
| 133 |
+
#include <ATen/ops/_foreach_reciprocal_meta.h>
|
| 134 |
+
#include <ATen/ops/_foreach_round_meta.h>
|
| 135 |
+
#include <ATen/ops/_foreach_sigmoid_meta.h>
|
| 136 |
+
#include <ATen/ops/_foreach_sign_meta.h>
|
| 137 |
+
#include <ATen/ops/_foreach_sin_meta.h>
|
| 138 |
+
#include <ATen/ops/_foreach_sinh_meta.h>
|
| 139 |
+
#include <ATen/ops/_foreach_sqrt_meta.h>
|
| 140 |
+
#include <ATen/ops/_foreach_sub_meta.h>
|
| 141 |
+
#include <ATen/ops/_foreach_tan_meta.h>
|
| 142 |
+
#include <ATen/ops/_foreach_tanh_meta.h>
|
| 143 |
+
#include <ATen/ops/_foreach_trunc_meta.h>
|
| 144 |
+
#include <ATen/ops/_foreach_zero_meta.h>
|
| 145 |
+
#include <ATen/ops/_functional_assert_async_meta.h>
|
| 146 |
+
#include <ATen/ops/_functional_assert_scalar_meta.h>
|
| 147 |
+
#include <ATen/ops/_functional_sym_constrain_range_meta.h>
|
| 148 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_meta.h>
|
| 149 |
+
#include <ATen/ops/_fused_adam_meta.h>
|
| 150 |
+
#include <ATen/ops/_fused_adamw_meta.h>
|
| 151 |
+
#include <ATen/ops/_fused_dropout_meta.h>
|
| 152 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_meta.h>
|
| 153 |
+
#include <ATen/ops/_fused_sdp_choice_meta.h>
|
| 154 |
+
#include <ATen/ops/_fused_sgd_meta.h>
|
| 155 |
+
#include <ATen/ops/_fw_primal_meta.h>
|
| 156 |
+
#include <ATen/ops/_fw_primal_copy_meta.h>
|
| 157 |
+
#include <ATen/ops/_gather_sparse_backward_meta.h>
|
| 158 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_meta.h>
|
| 159 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_meta.h>
|
| 160 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_meta.h>
|
| 161 |
+
#include <ATen/ops/_has_same_storage_numel_meta.h>
|
| 162 |
+
#include <ATen/ops/_histogramdd_bin_edges_meta.h>
|
| 163 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_meta.h>
|
| 164 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_meta.h>
|
| 165 |
+
#include <ATen/ops/_index_put_impl_meta.h>
|
| 166 |
+
#include <ATen/ops/_indices_meta.h>
|
| 167 |
+
#include <ATen/ops/_indices_copy_meta.h>
|
| 168 |
+
#include <ATen/ops/_int_mm_meta.h>
|
| 169 |
+
#include <ATen/ops/_is_all_true_meta.h>
|
| 170 |
+
#include <ATen/ops/_is_any_true_meta.h>
|
| 171 |
+
#include <ATen/ops/_is_zerotensor_meta.h>
|
| 172 |
+
#include <ATen/ops/_lazy_clone_meta.h>
|
| 173 |
+
#include <ATen/ops/_linalg_check_errors_meta.h>
|
| 174 |
+
#include <ATen/ops/_linalg_det_meta.h>
|
| 175 |
+
#include <ATen/ops/_linalg_eigh_meta.h>
|
| 176 |
+
#include <ATen/ops/_linalg_eigvals_meta.h>
|
| 177 |
+
#include <ATen/ops/_linalg_slogdet_meta.h>
|
| 178 |
+
#include <ATen/ops/_linalg_solve_ex_meta.h>
|
| 179 |
+
#include <ATen/ops/_linalg_svd_meta.h>
|
| 180 |
+
#include <ATen/ops/_local_scalar_dense_meta.h>
|
| 181 |
+
#include <ATen/ops/_log_softmax_meta.h>
|
| 182 |
+
#include <ATen/ops/_log_softmax_backward_data_meta.h>
|
| 183 |
+
#include <ATen/ops/_logcumsumexp_meta.h>
|
| 184 |
+
#include <ATen/ops/_lstm_mps_meta.h>
|
| 185 |
+
#include <ATen/ops/_lu_with_info_meta.h>
|
| 186 |
+
#include <ATen/ops/_make_dep_token_meta.h>
|
| 187 |
+
#include <ATen/ops/_make_dual_meta.h>
|
| 188 |
+
#include <ATen/ops/_make_dual_copy_meta.h>
|
| 189 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_meta.h>
|
| 190 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_meta.h>
|
| 191 |
+
#include <ATen/ops/_masked_scale_meta.h>
|
| 192 |
+
#include <ATen/ops/_masked_softmax_meta.h>
|
| 193 |
+
#include <ATen/ops/_masked_softmax_backward_meta.h>
|
| 194 |
+
#include <ATen/ops/_mixed_dtypes_linear_meta.h>
|
| 195 |
+
#include <ATen/ops/_mkldnn_reshape_meta.h>
|
| 196 |
+
#include <ATen/ops/_mkldnn_transpose_meta.h>
|
| 197 |
+
#include <ATen/ops/_mps_convolution_meta.h>
|
| 198 |
+
#include <ATen/ops/_mps_convolution_transpose_meta.h>
|
| 199 |
+
#include <ATen/ops/_native_batch_norm_legit_meta.h>
|
| 200 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_meta.h>
|
| 201 |
+
#include <ATen/ops/_native_multi_head_attention_meta.h>
|
| 202 |
+
#include <ATen/ops/_neg_view_meta.h>
|
| 203 |
+
#include <ATen/ops/_neg_view_copy_meta.h>
|
| 204 |
+
#include <ATen/ops/_nested_from_padded_meta.h>
|
| 205 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_meta.h>
|
| 206 |
+
#include <ATen/ops/_nested_get_jagged_dummy_meta.h>
|
| 207 |
+
#include <ATen/ops/_nested_get_lengths_meta.h>
|
| 208 |
+
#include <ATen/ops/_nested_get_offsets_meta.h>
|
| 209 |
+
#include <ATen/ops/_nested_get_ragged_idx_meta.h>
|
| 210 |
+
#include <ATen/ops/_nested_get_values_meta.h>
|
| 211 |
+
#include <ATen/ops/_nested_get_values_copy_meta.h>
|
| 212 |
+
#include <ATen/ops/_nested_select_backward_meta.h>
|
| 213 |
+
#include <ATen/ops/_nested_sum_backward_meta.h>
|
| 214 |
+
#include <ATen/ops/_nested_tensor_from_mask_meta.h>
|
| 215 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_meta.h>
|
| 216 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_meta.h>
|
| 217 |
+
#include <ATen/ops/_nested_tensor_size_meta.h>
|
| 218 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_meta.h>
|
| 219 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_meta.h>
|
| 220 |
+
#include <ATen/ops/_nested_tensor_strides_meta.h>
|
| 221 |
+
#include <ATen/ops/_nested_view_from_buffer_meta.h>
|
| 222 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_meta.h>
|
| 223 |
+
#include <ATen/ops/_nested_view_from_jagged_meta.h>
|
| 224 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_meta.h>
|
| 225 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_meta.h>
|
| 226 |
+
#include <ATen/ops/_nnpack_available_meta.h>
|
| 227 |
+
#include <ATen/ops/_nnpack_spatial_convolution_meta.h>
|
| 228 |
+
#include <ATen/ops/_nnz_meta.h>
|
| 229 |
+
#include <ATen/ops/_pack_padded_sequence_meta.h>
|
| 230 |
+
#include <ATen/ops/_pack_padded_sequence_backward_meta.h>
|
| 231 |
+
#include <ATen/ops/_pad_circular_meta.h>
|
| 232 |
+
#include <ATen/ops/_pad_enum_meta.h>
|
| 233 |
+
#include <ATen/ops/_pad_packed_sequence_meta.h>
|
| 234 |
+
#include <ATen/ops/_pdist_backward_meta.h>
|
| 235 |
+
#include <ATen/ops/_pdist_forward_meta.h>
|
| 236 |
+
#include <ATen/ops/_pin_memory_meta.h>
|
| 237 |
+
#include <ATen/ops/_prelu_kernel_meta.h>
|
| 238 |
+
#include <ATen/ops/_prelu_kernel_backward_meta.h>
|
| 239 |
+
#include <ATen/ops/_print_meta.h>
|
| 240 |
+
#include <ATen/ops/_propagate_xla_data_meta.h>
|
| 241 |
+
#include <ATen/ops/_remove_batch_dim_meta.h>
|
| 242 |
+
#include <ATen/ops/_reshape_alias_meta.h>
|
| 243 |
+
#include <ATen/ops/_reshape_alias_copy_meta.h>
|
| 244 |
+
#include <ATen/ops/_reshape_copy_meta.h>
|
| 245 |
+
#include <ATen/ops/_reshape_from_tensor_meta.h>
|
| 246 |
+
#include <ATen/ops/_resize_output_meta.h>
|
| 247 |
+
#include <ATen/ops/_rowwise_prune_meta.h>
|
| 248 |
+
#include <ATen/ops/_sample_dirichlet_meta.h>
|
| 249 |
+
#include <ATen/ops/_saturate_weight_to_fp16_meta.h>
|
| 250 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_meta.h>
|
| 251 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_meta.h>
|
| 252 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_meta.h>
|
| 253 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_meta.h>
|
| 254 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_meta.h>
|
| 255 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_meta.h>
|
| 256 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_meta.h>
|
| 257 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_meta.h>
|
| 258 |
+
#include <ATen/ops/_scaled_mm_meta.h>
|
| 259 |
+
#include <ATen/ops/_segment_reduce_backward_meta.h>
|
| 260 |
+
#include <ATen/ops/_shape_as_tensor_meta.h>
|
| 261 |
+
#include <ATen/ops/_slow_conv2d_backward_meta.h>
|
| 262 |
+
#include <ATen/ops/_slow_conv2d_forward_meta.h>
|
| 263 |
+
#include <ATen/ops/_sobol_engine_draw_meta.h>
|
| 264 |
+
#include <ATen/ops/_sobol_engine_ff_meta.h>
|
| 265 |
+
#include <ATen/ops/_sobol_engine_initialize_state_meta.h>
|
| 266 |
+
#include <ATen/ops/_sobol_engine_scramble_meta.h>
|
| 267 |
+
#include <ATen/ops/_softmax_meta.h>
|
| 268 |
+
#include <ATen/ops/_softmax_backward_data_meta.h>
|
| 269 |
+
#include <ATen/ops/_sparse_addmm_meta.h>
|
| 270 |
+
#include <ATen/ops/_sparse_broadcast_to_meta.h>
|
| 271 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_meta.h>
|
| 272 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_meta.h>
|
| 273 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_meta.h>
|
| 274 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_meta.h>
|
| 275 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_meta.h>
|
| 276 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta.h>
|
| 277 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta.h>
|
| 278 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_meta.h>
|
| 279 |
+
#include <ATen/ops/_sparse_csr_prod_meta.h>
|
| 280 |
+
#include <ATen/ops/_sparse_csr_sum_meta.h>
|
| 281 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_meta.h>
|
| 282 |
+
#include <ATen/ops/_sparse_log_softmax_meta.h>
|
| 283 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_meta.h>
|
| 284 |
+
#include <ATen/ops/_sparse_mask_projection_meta.h>
|
| 285 |
+
#include <ATen/ops/_sparse_mm_meta.h>
|
| 286 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_meta.h>
|
| 287 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_meta.h>
|
| 288 |
+
#include <ATen/ops/_sparse_semi_structured_linear_meta.h>
|
| 289 |
+
#include <ATen/ops/_sparse_softmax_meta.h>
|
| 290 |
+
#include <ATen/ops/_sparse_softmax_backward_data_meta.h>
|
| 291 |
+
#include <ATen/ops/_sparse_sparse_matmul_meta.h>
|
| 292 |
+
#include <ATen/ops/_sparse_sum_meta.h>
|
| 293 |
+
#include <ATen/ops/_sparse_sum_backward_meta.h>
|
| 294 |
+
#include <ATen/ops/_spdiags_meta.h>
|
| 295 |
+
#include <ATen/ops/_stack_meta.h>
|
| 296 |
+
#include <ATen/ops/_standard_gamma_meta.h>
|
| 297 |
+
#include <ATen/ops/_standard_gamma_grad_meta.h>
|
| 298 |
+
#include <ATen/ops/_test_ambiguous_defaults_meta.h>
|
| 299 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_meta.h>
|
| 300 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_meta.h>
|
| 301 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_meta.h>
|
| 302 |
+
#include <ATen/ops/_test_check_tensor_meta.h>
|
| 303 |
+
#include <ATen/ops/_test_functorch_fallback_meta.h>
|
| 304 |
+
#include <ATen/ops/_test_optional_filled_intlist_meta.h>
|
| 305 |
+
#include <ATen/ops/_test_optional_floatlist_meta.h>
|
| 306 |
+
#include <ATen/ops/_test_optional_intlist_meta.h>
|
| 307 |
+
#include <ATen/ops/_test_parallel_materialize_meta.h>
|
| 308 |
+
#include <ATen/ops/_test_serialization_subcmul_meta.h>
|
| 309 |
+
#include <ATen/ops/_test_string_default_meta.h>
|
| 310 |
+
#include <ATen/ops/_test_warn_in_autograd_meta.h>
|
| 311 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_meta.h>
|
| 312 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_meta.h>
|
| 313 |
+
#include <ATen/ops/_thnn_fused_gru_cell_meta.h>
|
| 314 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_meta.h>
|
| 315 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_meta.h>
|
| 316 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_meta.h>
|
| 317 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_meta.h>
|
| 318 |
+
#include <ATen/ops/_to_copy_meta.h>
|
| 319 |
+
#include <ATen/ops/_to_cpu_meta.h>
|
| 320 |
+
#include <ATen/ops/_to_dense_meta.h>
|
| 321 |
+
#include <ATen/ops/_to_sparse_meta.h>
|
| 322 |
+
#include <ATen/ops/_to_sparse_bsc_meta.h>
|
| 323 |
+
#include <ATen/ops/_to_sparse_bsr_meta.h>
|
| 324 |
+
#include <ATen/ops/_to_sparse_csc_meta.h>
|
| 325 |
+
#include <ATen/ops/_to_sparse_csr_meta.h>
|
| 326 |
+
#include <ATen/ops/_to_sparse_semi_structured_meta.h>
|
| 327 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_meta.h>
|
| 328 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_meta.h>
|
| 329 |
+
#include <ATen/ops/_trilinear_meta.h>
|
| 330 |
+
#include <ATen/ops/_triton_multi_head_attention_meta.h>
|
| 331 |
+
#include <ATen/ops/_triton_scaled_dot_attention_meta.h>
|
| 332 |
+
#include <ATen/ops/_unique_meta.h>
|
| 333 |
+
#include <ATen/ops/_unique2_meta.h>
|
| 334 |
+
#include <ATen/ops/_unpack_dual_meta.h>
|
| 335 |
+
#include <ATen/ops/_unsafe_index_meta.h>
|
| 336 |
+
#include <ATen/ops/_unsafe_index_put_meta.h>
|
| 337 |
+
#include <ATen/ops/_unsafe_view_meta.h>
|
| 338 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_meta.h>
|
| 339 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta.h>
|
| 340 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_meta.h>
|
| 341 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta.h>
|
| 342 |
+
#include <ATen/ops/_upsample_nearest_exact1d_meta.h>
|
| 343 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta.h>
|
| 344 |
+
#include <ATen/ops/_upsample_nearest_exact2d_meta.h>
|
| 345 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta.h>
|
| 346 |
+
#include <ATen/ops/_upsample_nearest_exact3d_meta.h>
|
| 347 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta.h>
|
| 348 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_meta.h>
|
| 349 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_meta.h>
|
| 350 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_meta.h>
|
| 351 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_meta.h>
|
| 352 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_meta.h>
|
| 353 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_meta.h>
|
| 354 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_meta.h>
|
| 355 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_meta.h>
|
| 356 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_meta.h>
|
| 357 |
+
#include <ATen/ops/_values_meta.h>
|
| 358 |
+
#include <ATen/ops/_values_copy_meta.h>
|
| 359 |
+
#include <ATen/ops/_version_meta.h>
|
| 360 |
+
#include <ATen/ops/_weight_int4pack_mm_meta.h>
|
| 361 |
+
#include <ATen/ops/_weight_int8pack_mm_meta.h>
|
| 362 |
+
#include <ATen/ops/_weight_norm_meta.h>
|
| 363 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_meta.h>
|
| 364 |
+
#include <ATen/ops/_weight_norm_interface_meta.h>
|
| 365 |
+
#include <ATen/ops/_weight_norm_interface_backward_meta.h>
|
| 366 |
+
#include <ATen/ops/abs_meta.h>
|
| 367 |
+
#include <ATen/ops/absolute_meta.h>
|
| 368 |
+
#include <ATen/ops/acos_meta.h>
|
| 369 |
+
#include <ATen/ops/acosh_meta.h>
|
| 370 |
+
#include <ATen/ops/adaptive_avg_pool1d_meta.h>
|
| 371 |
+
#include <ATen/ops/adaptive_avg_pool2d_meta.h>
|
| 372 |
+
#include <ATen/ops/adaptive_avg_pool3d_meta.h>
|
| 373 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_meta.h>
|
| 374 |
+
#include <ATen/ops/adaptive_max_pool1d_meta.h>
|
| 375 |
+
#include <ATen/ops/adaptive_max_pool2d_meta.h>
|
| 376 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_meta.h>
|
| 377 |
+
#include <ATen/ops/adaptive_max_pool3d_meta.h>
|
| 378 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
|
| 379 |
+
#include <ATen/ops/add_meta.h>
|
| 380 |
+
#include <ATen/ops/addbmm_meta.h>
|
| 381 |
+
#include <ATen/ops/addcdiv_meta.h>
|
| 382 |
+
#include <ATen/ops/addcmul_meta.h>
|
| 383 |
+
#include <ATen/ops/addmm_meta.h>
|
| 384 |
+
#include <ATen/ops/addmv_meta.h>
|
| 385 |
+
#include <ATen/ops/addr_meta.h>
|
| 386 |
+
#include <ATen/ops/adjoint_meta.h>
|
| 387 |
+
#include <ATen/ops/affine_grid_generator_meta.h>
|
| 388 |
+
#include <ATen/ops/affine_grid_generator_backward_meta.h>
|
| 389 |
+
#include <ATen/ops/alias_meta.h>
|
| 390 |
+
#include <ATen/ops/alias_copy_meta.h>
|
| 391 |
+
#include <ATen/ops/align_as_meta.h>
|
| 392 |
+
#include <ATen/ops/align_tensors_meta.h>
|
| 393 |
+
#include <ATen/ops/align_to_meta.h>
|
| 394 |
+
#include <ATen/ops/all_meta.h>
|
| 395 |
+
#include <ATen/ops/allclose_meta.h>
|
| 396 |
+
#include <ATen/ops/alpha_dropout_meta.h>
|
| 397 |
+
#include <ATen/ops/amax_meta.h>
|
| 398 |
+
#include <ATen/ops/amin_meta.h>
|
| 399 |
+
#include <ATen/ops/aminmax_meta.h>
|
| 400 |
+
#include <ATen/ops/and_meta.h>
|
| 401 |
+
#include <ATen/ops/angle_meta.h>
|
| 402 |
+
#include <ATen/ops/any_meta.h>
|
| 403 |
+
#include <ATen/ops/arange_meta.h>
|
| 404 |
+
#include <ATen/ops/arccos_meta.h>
|
| 405 |
+
#include <ATen/ops/arccosh_meta.h>
|
| 406 |
+
#include <ATen/ops/arcsin_meta.h>
|
| 407 |
+
#include <ATen/ops/arcsinh_meta.h>
|
| 408 |
+
#include <ATen/ops/arctan_meta.h>
|
| 409 |
+
#include <ATen/ops/arctan2_meta.h>
|
| 410 |
+
#include <ATen/ops/arctanh_meta.h>
|
| 411 |
+
#include <ATen/ops/argmax_meta.h>
|
| 412 |
+
#include <ATen/ops/argmin_meta.h>
|
| 413 |
+
#include <ATen/ops/argsort_meta.h>
|
| 414 |
+
#include <ATen/ops/argwhere_meta.h>
|
| 415 |
+
#include <ATen/ops/as_strided_meta.h>
|
| 416 |
+
#include <ATen/ops/as_strided_copy_meta.h>
|
| 417 |
+
#include <ATen/ops/as_strided_scatter_meta.h>
|
| 418 |
+
#include <ATen/ops/asin_meta.h>
|
| 419 |
+
#include <ATen/ops/asinh_meta.h>
|
| 420 |
+
#include <ATen/ops/atan_meta.h>
|
| 421 |
+
#include <ATen/ops/atan2_meta.h>
|
| 422 |
+
#include <ATen/ops/atanh_meta.h>
|
| 423 |
+
#include <ATen/ops/atleast_1d_meta.h>
|
| 424 |
+
#include <ATen/ops/atleast_2d_meta.h>
|
| 425 |
+
#include <ATen/ops/atleast_3d_meta.h>
|
| 426 |
+
#include <ATen/ops/avg_pool1d_meta.h>
|
| 427 |
+
#include <ATen/ops/avg_pool2d_meta.h>
|
| 428 |
+
#include <ATen/ops/avg_pool2d_backward_meta.h>
|
| 429 |
+
#include <ATen/ops/avg_pool3d_meta.h>
|
| 430 |
+
#include <ATen/ops/avg_pool3d_backward_meta.h>
|
| 431 |
+
#include <ATen/ops/baddbmm_meta.h>
|
| 432 |
+
#include <ATen/ops/bartlett_window_meta.h>
|
| 433 |
+
#include <ATen/ops/batch_norm_meta.h>
|
| 434 |
+
#include <ATen/ops/batch_norm_backward_elemt_meta.h>
|
| 435 |
+
#include <ATen/ops/batch_norm_backward_reduce_meta.h>
|
| 436 |
+
#include <ATen/ops/batch_norm_elemt_meta.h>
|
| 437 |
+
#include <ATen/ops/batch_norm_gather_stats_meta.h>
|
| 438 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_meta.h>
|
| 439 |
+
#include <ATen/ops/batch_norm_stats_meta.h>
|
| 440 |
+
#include <ATen/ops/batch_norm_update_stats_meta.h>
|
| 441 |
+
#include <ATen/ops/bernoulli_meta.h>
|
| 442 |
+
#include <ATen/ops/bilinear_meta.h>
|
| 443 |
+
#include <ATen/ops/binary_cross_entropy_meta.h>
|
| 444 |
+
#include <ATen/ops/binary_cross_entropy_backward_meta.h>
|
| 445 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_meta.h>
|
| 446 |
+
#include <ATen/ops/bincount_meta.h>
|
| 447 |
+
#include <ATen/ops/binomial_meta.h>
|
| 448 |
+
#include <ATen/ops/bitwise_and_meta.h>
|
| 449 |
+
#include <ATen/ops/bitwise_left_shift_meta.h>
|
| 450 |
+
#include <ATen/ops/bitwise_not_meta.h>
|
| 451 |
+
#include <ATen/ops/bitwise_or_meta.h>
|
| 452 |
+
#include <ATen/ops/bitwise_right_shift_meta.h>
|
| 453 |
+
#include <ATen/ops/bitwise_xor_meta.h>
|
| 454 |
+
#include <ATen/ops/blackman_window_meta.h>
|
| 455 |
+
#include <ATen/ops/block_diag_meta.h>
|
| 456 |
+
#include <ATen/ops/bmm_meta.h>
|
| 457 |
+
#include <ATen/ops/broadcast_tensors_meta.h>
|
| 458 |
+
#include <ATen/ops/broadcast_to_meta.h>
|
| 459 |
+
#include <ATen/ops/bucketize_meta.h>
|
| 460 |
+
#include <ATen/ops/can_cast_meta.h>
|
| 461 |
+
#include <ATen/ops/cartesian_prod_meta.h>
|
| 462 |
+
#include <ATen/ops/cat_meta.h>
|
| 463 |
+
#include <ATen/ops/cauchy_meta.h>
|
| 464 |
+
#include <ATen/ops/ccol_indices_meta.h>
|
| 465 |
+
#include <ATen/ops/ccol_indices_copy_meta.h>
|
| 466 |
+
#include <ATen/ops/cdist_meta.h>
|
| 467 |
+
#include <ATen/ops/ceil_meta.h>
|
| 468 |
+
#include <ATen/ops/celu_meta.h>
|
| 469 |
+
#include <ATen/ops/chain_matmul_meta.h>
|
| 470 |
+
#include <ATen/ops/chalf_meta.h>
|
| 471 |
+
#include <ATen/ops/channel_shuffle_meta.h>
|
| 472 |
+
#include <ATen/ops/cholesky_meta.h>
|
| 473 |
+
#include <ATen/ops/cholesky_inverse_meta.h>
|
| 474 |
+
#include <ATen/ops/cholesky_solve_meta.h>
|
| 475 |
+
#include <ATen/ops/choose_qparams_optimized_meta.h>
|
| 476 |
+
#include <ATen/ops/chunk_meta.h>
|
| 477 |
+
#include <ATen/ops/clamp_meta.h>
|
| 478 |
+
#include <ATen/ops/clamp_max_meta.h>
|
| 479 |
+
#include <ATen/ops/clamp_min_meta.h>
|
| 480 |
+
#include <ATen/ops/clip_meta.h>
|
| 481 |
+
#include <ATen/ops/clone_meta.h>
|
| 482 |
+
#include <ATen/ops/coalesce_meta.h>
|
| 483 |
+
#include <ATen/ops/col2im_meta.h>
|
| 484 |
+
#include <ATen/ops/col_indices_meta.h>
|
| 485 |
+
#include <ATen/ops/col_indices_copy_meta.h>
|
| 486 |
+
#include <ATen/ops/column_stack_meta.h>
|
| 487 |
+
#include <ATen/ops/combinations_meta.h>
|
| 488 |
+
#include <ATen/ops/complex_meta.h>
|
| 489 |
+
#include <ATen/ops/concat_meta.h>
|
| 490 |
+
#include <ATen/ops/concatenate_meta.h>
|
| 491 |
+
#include <ATen/ops/conj_meta.h>
|
| 492 |
+
#include <ATen/ops/conj_physical_meta.h>
|
| 493 |
+
#include <ATen/ops/constant_pad_nd_meta.h>
|
| 494 |
+
#include <ATen/ops/contiguous_meta.h>
|
| 495 |
+
#include <ATen/ops/conv1d_meta.h>
|
| 496 |
+
#include <ATen/ops/conv2d_meta.h>
|
| 497 |
+
#include <ATen/ops/conv3d_meta.h>
|
| 498 |
+
#include <ATen/ops/conv_depthwise3d_meta.h>
|
| 499 |
+
#include <ATen/ops/conv_tbc_meta.h>
|
| 500 |
+
#include <ATen/ops/conv_tbc_backward_meta.h>
|
| 501 |
+
#include <ATen/ops/conv_transpose1d_meta.h>
|
| 502 |
+
#include <ATen/ops/conv_transpose2d_meta.h>
|
| 503 |
+
#include <ATen/ops/conv_transpose3d_meta.h>
|
| 504 |
+
#include <ATen/ops/convolution_meta.h>
|
| 505 |
+
#include <ATen/ops/convolution_backward_meta.h>
|
| 506 |
+
#include <ATen/ops/convolution_backward_overrideable_meta.h>
|
| 507 |
+
#include <ATen/ops/convolution_overrideable_meta.h>
|
| 508 |
+
#include <ATen/ops/copy_meta.h>
|
| 509 |
+
#include <ATen/ops/copy_sparse_to_sparse_meta.h>
|
| 510 |
+
#include <ATen/ops/copysign_meta.h>
|
| 511 |
+
#include <ATen/ops/corrcoef_meta.h>
|
| 512 |
+
#include <ATen/ops/cos_meta.h>
|
| 513 |
+
#include <ATen/ops/cosh_meta.h>
|
| 514 |
+
#include <ATen/ops/cosine_embedding_loss_meta.h>
|
| 515 |
+
#include <ATen/ops/cosine_similarity_meta.h>
|
| 516 |
+
#include <ATen/ops/count_nonzero_meta.h>
|
| 517 |
+
#include <ATen/ops/cov_meta.h>
|
| 518 |
+
#include <ATen/ops/cross_meta.h>
|
| 519 |
+
#include <ATen/ops/cross_entropy_loss_meta.h>
|
| 520 |
+
#include <ATen/ops/crow_indices_meta.h>
|
| 521 |
+
#include <ATen/ops/crow_indices_copy_meta.h>
|
| 522 |
+
#include <ATen/ops/ctc_loss_meta.h>
|
| 523 |
+
#include <ATen/ops/cudnn_affine_grid_generator_meta.h>
|
| 524 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_meta.h>
|
| 525 |
+
#include <ATen/ops/cudnn_batch_norm_meta.h>
|
| 526 |
+
#include <ATen/ops/cudnn_batch_norm_backward_meta.h>
|
| 527 |
+
#include <ATen/ops/cudnn_convolution_meta.h>
|
| 528 |
+
#include <ATen/ops/cudnn_convolution_add_relu_meta.h>
|
| 529 |
+
#include <ATen/ops/cudnn_convolution_relu_meta.h>
|
| 530 |
+
#include <ATen/ops/cudnn_convolution_transpose_meta.h>
|
| 531 |
+
#include <ATen/ops/cudnn_grid_sampler_meta.h>
|
| 532 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_meta.h>
|
| 533 |
+
#include <ATen/ops/cudnn_is_acceptable_meta.h>
|
| 534 |
+
#include <ATen/ops/cummax_meta.h>
|
| 535 |
+
#include <ATen/ops/cummaxmin_backward_meta.h>
|
| 536 |
+
#include <ATen/ops/cummin_meta.h>
|
| 537 |
+
#include <ATen/ops/cumprod_meta.h>
|
| 538 |
+
#include <ATen/ops/cumprod_backward_meta.h>
|
| 539 |
+
#include <ATen/ops/cumsum_meta.h>
|
| 540 |
+
#include <ATen/ops/cumulative_trapezoid_meta.h>
|
| 541 |
+
#include <ATen/ops/data_meta.h>
|
| 542 |
+
#include <ATen/ops/deg2rad_meta.h>
|
| 543 |
+
#include <ATen/ops/dense_dim_meta.h>
|
| 544 |
+
#include <ATen/ops/dequantize_meta.h>
|
| 545 |
+
#include <ATen/ops/det_meta.h>
|
| 546 |
+
#include <ATen/ops/detach_meta.h>
|
| 547 |
+
#include <ATen/ops/detach_copy_meta.h>
|
| 548 |
+
#include <ATen/ops/diag_meta.h>
|
| 549 |
+
#include <ATen/ops/diag_embed_meta.h>
|
| 550 |
+
#include <ATen/ops/diagflat_meta.h>
|
| 551 |
+
#include <ATen/ops/diagonal_meta.h>
|
| 552 |
+
#include <ATen/ops/diagonal_backward_meta.h>
|
| 553 |
+
#include <ATen/ops/diagonal_copy_meta.h>
|
| 554 |
+
#include <ATen/ops/diagonal_scatter_meta.h>
|
| 555 |
+
#include <ATen/ops/diff_meta.h>
|
| 556 |
+
#include <ATen/ops/digamma_meta.h>
|
| 557 |
+
#include <ATen/ops/dist_meta.h>
|
| 558 |
+
#include <ATen/ops/div_meta.h>
|
| 559 |
+
#include <ATen/ops/divide_meta.h>
|
| 560 |
+
#include <ATen/ops/dot_meta.h>
|
| 561 |
+
#include <ATen/ops/dropout_meta.h>
|
| 562 |
+
#include <ATen/ops/dsplit_meta.h>
|
| 563 |
+
#include <ATen/ops/dstack_meta.h>
|
| 564 |
+
#include <ATen/ops/einsum_meta.h>
|
| 565 |
+
#include <ATen/ops/elu_meta.h>
|
| 566 |
+
#include <ATen/ops/elu_backward_meta.h>
|
| 567 |
+
#include <ATen/ops/embedding_meta.h>
|
| 568 |
+
#include <ATen/ops/embedding_backward_meta.h>
|
| 569 |
+
#include <ATen/ops/embedding_bag_meta.h>
|
| 570 |
+
#include <ATen/ops/embedding_dense_backward_meta.h>
|
| 571 |
+
#include <ATen/ops/embedding_renorm_meta.h>
|
| 572 |
+
#include <ATen/ops/embedding_sparse_backward_meta.h>
|
| 573 |
+
#include <ATen/ops/empty_meta.h>
|
| 574 |
+
#include <ATen/ops/empty_like_meta.h>
|
| 575 |
+
#include <ATen/ops/empty_permuted_meta.h>
|
| 576 |
+
#include <ATen/ops/empty_quantized_meta.h>
|
| 577 |
+
#include <ATen/ops/empty_strided_meta.h>
|
| 578 |
+
#include <ATen/ops/eq_meta.h>
|
| 579 |
+
#include <ATen/ops/equal_meta.h>
|
| 580 |
+
#include <ATen/ops/erf_meta.h>
|
| 581 |
+
#include <ATen/ops/erfc_meta.h>
|
| 582 |
+
#include <ATen/ops/erfinv_meta.h>
|
| 583 |
+
#include <ATen/ops/exp_meta.h>
|
| 584 |
+
#include <ATen/ops/exp2_meta.h>
|
| 585 |
+
#include <ATen/ops/expand_meta.h>
|
| 586 |
+
#include <ATen/ops/expand_as_meta.h>
|
| 587 |
+
#include <ATen/ops/expand_copy_meta.h>
|
| 588 |
+
#include <ATen/ops/expm1_meta.h>
|
| 589 |
+
#include <ATen/ops/exponential_meta.h>
|
| 590 |
+
#include <ATen/ops/eye_meta.h>
|
| 591 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_meta.h>
|
| 592 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_meta.h>
|
| 593 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_meta.h>
|
| 594 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_meta.h>
|
| 595 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_meta.h>
|
| 596 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_meta.h>
|
| 597 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_meta.h>
|
| 598 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_meta.h>
|
| 599 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_meta.h>
|
| 600 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_meta.h>
|
| 601 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_meta.h>
|
| 602 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_meta.h>
|
| 603 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_meta.h>
|
| 604 |
+
#include <ATen/ops/feature_alpha_dropout_meta.h>
|
| 605 |
+
#include <ATen/ops/feature_dropout_meta.h>
|
| 606 |
+
#include <ATen/ops/fft_fft_meta.h>
|
| 607 |
+
#include <ATen/ops/fft_fft2_meta.h>
|
| 608 |
+
#include <ATen/ops/fft_fftfreq_meta.h>
|
| 609 |
+
#include <ATen/ops/fft_fftn_meta.h>
|
| 610 |
+
#include <ATen/ops/fft_fftshift_meta.h>
|
| 611 |
+
#include <ATen/ops/fft_hfft_meta.h>
|
| 612 |
+
#include <ATen/ops/fft_hfft2_meta.h>
|
| 613 |
+
#include <ATen/ops/fft_hfftn_meta.h>
|
| 614 |
+
#include <ATen/ops/fft_ifft_meta.h>
|
| 615 |
+
#include <ATen/ops/fft_ifft2_meta.h>
|
| 616 |
+
#include <ATen/ops/fft_ifftn_meta.h>
|
| 617 |
+
#include <ATen/ops/fft_ifftshift_meta.h>
|
| 618 |
+
#include <ATen/ops/fft_ihfft_meta.h>
|
| 619 |
+
#include <ATen/ops/fft_ihfft2_meta.h>
|
| 620 |
+
#include <ATen/ops/fft_ihfftn_meta.h>
|
| 621 |
+
#include <ATen/ops/fft_irfft_meta.h>
|
| 622 |
+
#include <ATen/ops/fft_irfft2_meta.h>
|
| 623 |
+
#include <ATen/ops/fft_irfftn_meta.h>
|
| 624 |
+
#include <ATen/ops/fft_rfft_meta.h>
|
| 625 |
+
#include <ATen/ops/fft_rfft2_meta.h>
|
| 626 |
+
#include <ATen/ops/fft_rfftfreq_meta.h>
|
| 627 |
+
#include <ATen/ops/fft_rfftn_meta.h>
|
| 628 |
+
#include <ATen/ops/fill_meta.h>
|
| 629 |
+
#include <ATen/ops/fill_diagonal_meta.h>
|
| 630 |
+
#include <ATen/ops/fix_meta.h>
|
| 631 |
+
#include <ATen/ops/flatten_meta.h>
|
| 632 |
+
#include <ATen/ops/flatten_dense_tensors_meta.h>
|
| 633 |
+
#include <ATen/ops/flip_meta.h>
|
| 634 |
+
#include <ATen/ops/fliplr_meta.h>
|
| 635 |
+
#include <ATen/ops/flipud_meta.h>
|
| 636 |
+
#include <ATen/ops/float_power_meta.h>
|
| 637 |
+
#include <ATen/ops/floor_meta.h>
|
| 638 |
+
#include <ATen/ops/floor_divide_meta.h>
|
| 639 |
+
#include <ATen/ops/fmax_meta.h>
|
| 640 |
+
#include <ATen/ops/fmin_meta.h>
|
| 641 |
+
#include <ATen/ops/fmod_meta.h>
|
| 642 |
+
#include <ATen/ops/frac_meta.h>
|
| 643 |
+
#include <ATen/ops/fractional_max_pool2d_meta.h>
|
| 644 |
+
#include <ATen/ops/fractional_max_pool2d_backward_meta.h>
|
| 645 |
+
#include <ATen/ops/fractional_max_pool3d_meta.h>
|
| 646 |
+
#include <ATen/ops/fractional_max_pool3d_backward_meta.h>
|
| 647 |
+
#include <ATen/ops/frexp_meta.h>
|
| 648 |
+
#include <ATen/ops/frobenius_norm_meta.h>
|
| 649 |
+
#include <ATen/ops/from_file_meta.h>
|
| 650 |
+
#include <ATen/ops/full_meta.h>
|
| 651 |
+
#include <ATen/ops/full_like_meta.h>
|
| 652 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_meta.h>
|
| 653 |
+
#include <ATen/ops/gather_meta.h>
|
| 654 |
+
#include <ATen/ops/gather_backward_meta.h>
|
| 655 |
+
#include <ATen/ops/gcd_meta.h>
|
| 656 |
+
#include <ATen/ops/ge_meta.h>
|
| 657 |
+
#include <ATen/ops/gelu_meta.h>
|
| 658 |
+
#include <ATen/ops/gelu_backward_meta.h>
|
| 659 |
+
#include <ATen/ops/geometric_meta.h>
|
| 660 |
+
#include <ATen/ops/geqrf_meta.h>
|
| 661 |
+
#include <ATen/ops/ger_meta.h>
|
| 662 |
+
#include <ATen/ops/glu_meta.h>
|
| 663 |
+
#include <ATen/ops/glu_backward_meta.h>
|
| 664 |
+
#include <ATen/ops/glu_backward_jvp_meta.h>
|
| 665 |
+
#include <ATen/ops/glu_jvp_meta.h>
|
| 666 |
+
#include <ATen/ops/gradient_meta.h>
|
| 667 |
+
#include <ATen/ops/greater_meta.h>
|
| 668 |
+
#include <ATen/ops/greater_equal_meta.h>
|
| 669 |
+
#include <ATen/ops/grid_sampler_meta.h>
|
| 670 |
+
#include <ATen/ops/grid_sampler_2d_meta.h>
|
| 671 |
+
#include <ATen/ops/grid_sampler_2d_backward_meta.h>
|
| 672 |
+
#include <ATen/ops/grid_sampler_3d_meta.h>
|
| 673 |
+
#include <ATen/ops/grid_sampler_3d_backward_meta.h>
|
| 674 |
+
#include <ATen/ops/group_norm_meta.h>
|
| 675 |
+
#include <ATen/ops/gru_meta.h>
|
| 676 |
+
#include <ATen/ops/gru_cell_meta.h>
|
| 677 |
+
#include <ATen/ops/gt_meta.h>
|
| 678 |
+
#include <ATen/ops/hamming_window_meta.h>
|
| 679 |
+
#include <ATen/ops/hann_window_meta.h>
|
| 680 |
+
#include <ATen/ops/hardshrink_meta.h>
|
| 681 |
+
#include <ATen/ops/hardshrink_backward_meta.h>
|
| 682 |
+
#include <ATen/ops/hardsigmoid_meta.h>
|
| 683 |
+
#include <ATen/ops/hardsigmoid_backward_meta.h>
|
| 684 |
+
#include <ATen/ops/hardswish_meta.h>
|
| 685 |
+
#include <ATen/ops/hardswish_backward_meta.h>
|
| 686 |
+
#include <ATen/ops/hardtanh_meta.h>
|
| 687 |
+
#include <ATen/ops/hardtanh_backward_meta.h>
|
| 688 |
+
#include <ATen/ops/heaviside_meta.h>
|
| 689 |
+
#include <ATen/ops/hinge_embedding_loss_meta.h>
|
| 690 |
+
#include <ATen/ops/histc_meta.h>
|
| 691 |
+
#include <ATen/ops/histogram_meta.h>
|
| 692 |
+
#include <ATen/ops/histogramdd_meta.h>
|
| 693 |
+
#include <ATen/ops/hsplit_meta.h>
|
| 694 |
+
#include <ATen/ops/hspmm_meta.h>
|
| 695 |
+
#include <ATen/ops/hstack_meta.h>
|
| 696 |
+
#include <ATen/ops/huber_loss_meta.h>
|
| 697 |
+
#include <ATen/ops/huber_loss_backward_meta.h>
|
| 698 |
+
#include <ATen/ops/hypot_meta.h>
|
| 699 |
+
#include <ATen/ops/i0_meta.h>
|
| 700 |
+
#include <ATen/ops/igamma_meta.h>
|
| 701 |
+
#include <ATen/ops/igammac_meta.h>
|
| 702 |
+
#include <ATen/ops/im2col_meta.h>
|
| 703 |
+
#include <ATen/ops/imag_meta.h>
|
| 704 |
+
#include <ATen/ops/index_meta.h>
|
| 705 |
+
#include <ATen/ops/index_add_meta.h>
|
| 706 |
+
#include <ATen/ops/index_copy_meta.h>
|
| 707 |
+
#include <ATen/ops/index_fill_meta.h>
|
| 708 |
+
#include <ATen/ops/index_put_meta.h>
|
| 709 |
+
#include <ATen/ops/index_reduce_meta.h>
|
| 710 |
+
#include <ATen/ops/index_select_meta.h>
|
| 711 |
+
#include <ATen/ops/index_select_backward_meta.h>
|
| 712 |
+
#include <ATen/ops/indices_meta.h>
|
| 713 |
+
#include <ATen/ops/indices_copy_meta.h>
|
| 714 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_meta.h>
|
| 715 |
+
#include <ATen/ops/inner_meta.h>
|
| 716 |
+
#include <ATen/ops/instance_norm_meta.h>
|
| 717 |
+
#include <ATen/ops/int_repr_meta.h>
|
| 718 |
+
#include <ATen/ops/inverse_meta.h>
|
| 719 |
+
#include <ATen/ops/is_coalesced_meta.h>
|
| 720 |
+
#include <ATen/ops/is_complex_meta.h>
|
| 721 |
+
#include <ATen/ops/is_conj_meta.h>
|
| 722 |
+
#include <ATen/ops/is_distributed_meta.h>
|
| 723 |
+
#include <ATen/ops/is_floating_point_meta.h>
|
| 724 |
+
#include <ATen/ops/is_inference_meta.h>
|
| 725 |
+
#include <ATen/ops/is_leaf_meta.h>
|
| 726 |
+
#include <ATen/ops/is_neg_meta.h>
|
| 727 |
+
#include <ATen/ops/is_nonzero_meta.h>
|
| 728 |
+
#include <ATen/ops/is_pinned_meta.h>
|
| 729 |
+
#include <ATen/ops/is_same_size_meta.h>
|
| 730 |
+
#include <ATen/ops/is_set_to_meta.h>
|
| 731 |
+
#include <ATen/ops/is_signed_meta.h>
|
| 732 |
+
#include <ATen/ops/is_vulkan_available_meta.h>
|
| 733 |
+
#include <ATen/ops/isclose_meta.h>
|
| 734 |
+
#include <ATen/ops/isfinite_meta.h>
|
| 735 |
+
#include <ATen/ops/isin_meta.h>
|
| 736 |
+
#include <ATen/ops/isinf_meta.h>
|
| 737 |
+
#include <ATen/ops/isnan_meta.h>
|
| 738 |
+
#include <ATen/ops/isneginf_meta.h>
|
| 739 |
+
#include <ATen/ops/isposinf_meta.h>
|
| 740 |
+
#include <ATen/ops/isreal_meta.h>
|
| 741 |
+
#include <ATen/ops/istft_meta.h>
|
| 742 |
+
#include <ATen/ops/item_meta.h>
|
| 743 |
+
#include <ATen/ops/kaiser_window_meta.h>
|
| 744 |
+
#include <ATen/ops/kl_div_meta.h>
|
| 745 |
+
#include <ATen/ops/kron_meta.h>
|
| 746 |
+
#include <ATen/ops/kthvalue_meta.h>
|
| 747 |
+
#include <ATen/ops/l1_loss_meta.h>
|
| 748 |
+
#include <ATen/ops/layer_norm_meta.h>
|
| 749 |
+
#include <ATen/ops/lcm_meta.h>
|
| 750 |
+
#include <ATen/ops/ldexp_meta.h>
|
| 751 |
+
#include <ATen/ops/le_meta.h>
|
| 752 |
+
#include <ATen/ops/leaky_relu_meta.h>
|
| 753 |
+
#include <ATen/ops/leaky_relu_backward_meta.h>
|
| 754 |
+
#include <ATen/ops/lerp_meta.h>
|
| 755 |
+
#include <ATen/ops/less_meta.h>
|
| 756 |
+
#include <ATen/ops/less_equal_meta.h>
|
| 757 |
+
#include <ATen/ops/lgamma_meta.h>
|
| 758 |
+
#include <ATen/ops/lift_meta.h>
|
| 759 |
+
#include <ATen/ops/lift_fresh_meta.h>
|
| 760 |
+
#include <ATen/ops/lift_fresh_copy_meta.h>
|
| 761 |
+
#include <ATen/ops/linalg_cholesky_meta.h>
|
| 762 |
+
#include <ATen/ops/linalg_cholesky_ex_meta.h>
|
| 763 |
+
#include <ATen/ops/linalg_cond_meta.h>
|
| 764 |
+
#include <ATen/ops/linalg_cross_meta.h>
|
| 765 |
+
#include <ATen/ops/linalg_det_meta.h>
|
| 766 |
+
#include <ATen/ops/linalg_diagonal_meta.h>
|
| 767 |
+
#include <ATen/ops/linalg_eig_meta.h>
|
| 768 |
+
#include <ATen/ops/linalg_eigh_meta.h>
|
| 769 |
+
#include <ATen/ops/linalg_eigvals_meta.h>
|
| 770 |
+
#include <ATen/ops/linalg_eigvalsh_meta.h>
|
| 771 |
+
#include <ATen/ops/linalg_householder_product_meta.h>
|
| 772 |
+
#include <ATen/ops/linalg_inv_meta.h>
|
| 773 |
+
#include <ATen/ops/linalg_inv_ex_meta.h>
|
| 774 |
+
#include <ATen/ops/linalg_ldl_factor_meta.h>
|
| 775 |
+
#include <ATen/ops/linalg_ldl_factor_ex_meta.h>
|
| 776 |
+
#include <ATen/ops/linalg_ldl_solve_meta.h>
|
| 777 |
+
#include <ATen/ops/linalg_lstsq_meta.h>
|
| 778 |
+
#include <ATen/ops/linalg_lu_meta.h>
|
| 779 |
+
#include <ATen/ops/linalg_lu_factor_meta.h>
|
| 780 |
+
#include <ATen/ops/linalg_lu_factor_ex_meta.h>
|
| 781 |
+
#include <ATen/ops/linalg_lu_solve_meta.h>
|
| 782 |
+
#include <ATen/ops/linalg_matmul_meta.h>
|
| 783 |
+
#include <ATen/ops/linalg_matrix_exp_meta.h>
|
| 784 |
+
#include <ATen/ops/linalg_matrix_norm_meta.h>
|
| 785 |
+
#include <ATen/ops/linalg_matrix_power_meta.h>
|
| 786 |
+
#include <ATen/ops/linalg_matrix_rank_meta.h>
|
| 787 |
+
#include <ATen/ops/linalg_multi_dot_meta.h>
|
| 788 |
+
#include <ATen/ops/linalg_norm_meta.h>
|
| 789 |
+
#include <ATen/ops/linalg_pinv_meta.h>
|
| 790 |
+
#include <ATen/ops/linalg_qr_meta.h>
|
| 791 |
+
#include <ATen/ops/linalg_slogdet_meta.h>
|
| 792 |
+
#include <ATen/ops/linalg_solve_meta.h>
|
| 793 |
+
#include <ATen/ops/linalg_solve_ex_meta.h>
|
| 794 |
+
#include <ATen/ops/linalg_solve_triangular_meta.h>
|
| 795 |
+
#include <ATen/ops/linalg_svd_meta.h>
|
| 796 |
+
#include <ATen/ops/linalg_svdvals_meta.h>
|
| 797 |
+
#include <ATen/ops/linalg_tensorinv_meta.h>
|
| 798 |
+
#include <ATen/ops/linalg_tensorsolve_meta.h>
|
| 799 |
+
#include <ATen/ops/linalg_vander_meta.h>
|
| 800 |
+
#include <ATen/ops/linalg_vecdot_meta.h>
|
| 801 |
+
#include <ATen/ops/linalg_vector_norm_meta.h>
|
| 802 |
+
#include <ATen/ops/linear_meta.h>
|
| 803 |
+
#include <ATen/ops/linear_backward_meta.h>
|
| 804 |
+
#include <ATen/ops/linspace_meta.h>
|
| 805 |
+
#include <ATen/ops/log_meta.h>
|
| 806 |
+
#include <ATen/ops/log10_meta.h>
|
| 807 |
+
#include <ATen/ops/log1p_meta.h>
|
| 808 |
+
#include <ATen/ops/log2_meta.h>
|
| 809 |
+
#include <ATen/ops/log_normal_meta.h>
|
| 810 |
+
#include <ATen/ops/log_sigmoid_meta.h>
|
| 811 |
+
#include <ATen/ops/log_sigmoid_backward_meta.h>
|
| 812 |
+
#include <ATen/ops/log_sigmoid_forward_meta.h>
|
| 813 |
+
#include <ATen/ops/log_softmax_meta.h>
|
| 814 |
+
#include <ATen/ops/logaddexp_meta.h>
|
| 815 |
+
#include <ATen/ops/logaddexp2_meta.h>
|
| 816 |
+
#include <ATen/ops/logcumsumexp_meta.h>
|
| 817 |
+
#include <ATen/ops/logdet_meta.h>
|
| 818 |
+
#include <ATen/ops/logical_and_meta.h>
|
| 819 |
+
#include <ATen/ops/logical_not_meta.h>
|
| 820 |
+
#include <ATen/ops/logical_or_meta.h>
|
| 821 |
+
#include <ATen/ops/logical_xor_meta.h>
|
| 822 |
+
#include <ATen/ops/logit_meta.h>
|
| 823 |
+
#include <ATen/ops/logit_backward_meta.h>
|
| 824 |
+
#include <ATen/ops/logspace_meta.h>
|
| 825 |
+
#include <ATen/ops/logsumexp_meta.h>
|
| 826 |
+
#include <ATen/ops/lshift_meta.h>
|
| 827 |
+
#include <ATen/ops/lstm_meta.h>
|
| 828 |
+
#include <ATen/ops/lstm_cell_meta.h>
|
| 829 |
+
#include <ATen/ops/lstm_mps_backward_meta.h>
|
| 830 |
+
#include <ATen/ops/lt_meta.h>
|
| 831 |
+
#include <ATen/ops/lu_solve_meta.h>
|
| 832 |
+
#include <ATen/ops/lu_unpack_meta.h>
|
| 833 |
+
#include <ATen/ops/mH_meta.h>
|
| 834 |
+
#include <ATen/ops/mT_meta.h>
|
| 835 |
+
#include <ATen/ops/margin_ranking_loss_meta.h>
|
| 836 |
+
#include <ATen/ops/masked_fill_meta.h>
|
| 837 |
+
#include <ATen/ops/masked_scatter_meta.h>
|
| 838 |
+
#include <ATen/ops/masked_scatter_backward_meta.h>
|
| 839 |
+
#include <ATen/ops/masked_select_meta.h>
|
| 840 |
+
#include <ATen/ops/masked_select_backward_meta.h>
|
| 841 |
+
#include <ATen/ops/matmul_meta.h>
|
| 842 |
+
#include <ATen/ops/matmul_backward_meta.h>
|
| 843 |
+
#include <ATen/ops/matrix_H_meta.h>
|
| 844 |
+
#include <ATen/ops/matrix_exp_meta.h>
|
| 845 |
+
#include <ATen/ops/matrix_exp_backward_meta.h>
|
| 846 |
+
#include <ATen/ops/matrix_power_meta.h>
|
| 847 |
+
#include <ATen/ops/max_meta.h>
|
| 848 |
+
#include <ATen/ops/max_pool1d_meta.h>
|
| 849 |
+
#include <ATen/ops/max_pool1d_with_indices_meta.h>
|
| 850 |
+
#include <ATen/ops/max_pool2d_meta.h>
|
| 851 |
+
#include <ATen/ops/max_pool2d_backward_meta.h>
|
| 852 |
+
#include <ATen/ops/max_pool2d_with_indices_meta.h>
|
| 853 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_meta.h>
|
| 854 |
+
#include <ATen/ops/max_pool3d_meta.h>
|
| 855 |
+
#include <ATen/ops/max_pool3d_with_indices_meta.h>
|
| 856 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_meta.h>
|
| 857 |
+
#include <ATen/ops/max_unpool2d_meta.h>
|
| 858 |
+
#include <ATen/ops/max_unpool3d_meta.h>
|
| 859 |
+
#include <ATen/ops/maximum_meta.h>
|
| 860 |
+
#include <ATen/ops/mean_meta.h>
|
| 861 |
+
#include <ATen/ops/median_meta.h>
|
| 862 |
+
#include <ATen/ops/meshgrid_meta.h>
|
| 863 |
+
#include <ATen/ops/min_meta.h>
|
| 864 |
+
#include <ATen/ops/minimum_meta.h>
|
| 865 |
+
#include <ATen/ops/miopen_batch_norm_meta.h>
|
| 866 |
+
#include <ATen/ops/miopen_batch_norm_backward_meta.h>
|
| 867 |
+
#include <ATen/ops/miopen_convolution_meta.h>
|
| 868 |
+
#include <ATen/ops/miopen_convolution_add_relu_meta.h>
|
| 869 |
+
#include <ATen/ops/miopen_convolution_relu_meta.h>
|
| 870 |
+
#include <ATen/ops/miopen_convolution_transpose_meta.h>
|
| 871 |
+
#include <ATen/ops/miopen_depthwise_convolution_meta.h>
|
| 872 |
+
#include <ATen/ops/miopen_rnn_meta.h>
|
| 873 |
+
#include <ATen/ops/miopen_rnn_backward_meta.h>
|
| 874 |
+
#include <ATen/ops/mish_meta.h>
|
| 875 |
+
#include <ATen/ops/mish_backward_meta.h>
|
| 876 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_meta.h>
|
| 877 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_meta.h>
|
| 878 |
+
#include <ATen/ops/mkldnn_convolution_meta.h>
|
| 879 |
+
#include <ATen/ops/mkldnn_linear_meta.h>
|
| 880 |
+
#include <ATen/ops/mkldnn_linear_backward_meta.h>
|
| 881 |
+
#include <ATen/ops/mkldnn_linear_backward_input_meta.h>
|
| 882 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_meta.h>
|
| 883 |
+
#include <ATen/ops/mkldnn_max_pool2d_meta.h>
|
| 884 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_meta.h>
|
| 885 |
+
#include <ATen/ops/mkldnn_max_pool3d_meta.h>
|
| 886 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_meta.h>
|
| 887 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_meta.h>
|
| 888 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_meta.h>
|
| 889 |
+
#include <ATen/ops/mkldnn_rnn_layer_meta.h>
|
| 890 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_meta.h>
|
| 891 |
+
#include <ATen/ops/mm_meta.h>
|
| 892 |
+
#include <ATen/ops/mode_meta.h>
|
| 893 |
+
#include <ATen/ops/moveaxis_meta.h>
|
| 894 |
+
#include <ATen/ops/movedim_meta.h>
|
| 895 |
+
#include <ATen/ops/mps_convolution_backward_meta.h>
|
| 896 |
+
#include <ATen/ops/mps_convolution_transpose_backward_meta.h>
|
| 897 |
+
#include <ATen/ops/mse_loss_meta.h>
|
| 898 |
+
#include <ATen/ops/mse_loss_backward_meta.h>
|
| 899 |
+
#include <ATen/ops/msort_meta.h>
|
| 900 |
+
#include <ATen/ops/mul_meta.h>
|
| 901 |
+
#include <ATen/ops/multi_margin_loss_meta.h>
|
| 902 |
+
#include <ATen/ops/multi_margin_loss_backward_meta.h>
|
| 903 |
+
#include <ATen/ops/multilabel_margin_loss_meta.h>
|
| 904 |
+
#include <ATen/ops/multilabel_margin_loss_backward_meta.h>
|
| 905 |
+
#include <ATen/ops/multilabel_margin_loss_forward_meta.h>
|
| 906 |
+
#include <ATen/ops/multinomial_meta.h>
|
| 907 |
+
#include <ATen/ops/multiply_meta.h>
|
| 908 |
+
#include <ATen/ops/mv_meta.h>
|
| 909 |
+
#include <ATen/ops/mvlgamma_meta.h>
|
| 910 |
+
#include <ATen/ops/nan_to_num_meta.h>
|
| 911 |
+
#include <ATen/ops/nanmean_meta.h>
|
| 912 |
+
#include <ATen/ops/nanmedian_meta.h>
|
| 913 |
+
#include <ATen/ops/nanquantile_meta.h>
|
| 914 |
+
#include <ATen/ops/nansum_meta.h>
|
| 915 |
+
#include <ATen/ops/narrow_meta.h>
|
| 916 |
+
#include <ATen/ops/narrow_copy_meta.h>
|
| 917 |
+
#include <ATen/ops/native_batch_norm_meta.h>
|
| 918 |
+
#include <ATen/ops/native_batch_norm_backward_meta.h>
|
| 919 |
+
#include <ATen/ops/native_channel_shuffle_meta.h>
|
| 920 |
+
#include <ATen/ops/native_dropout_meta.h>
|
| 921 |
+
#include <ATen/ops/native_dropout_backward_meta.h>
|
| 922 |
+
#include <ATen/ops/native_group_norm_meta.h>
|
| 923 |
+
#include <ATen/ops/native_group_norm_backward_meta.h>
|
| 924 |
+
#include <ATen/ops/native_layer_norm_meta.h>
|
| 925 |
+
#include <ATen/ops/native_layer_norm_backward_meta.h>
|
| 926 |
+
#include <ATen/ops/native_norm_meta.h>
|
| 927 |
+
#include <ATen/ops/ne_meta.h>
|
| 928 |
+
#include <ATen/ops/neg_meta.h>
|
| 929 |
+
#include <ATen/ops/negative_meta.h>
|
| 930 |
+
#include <ATen/ops/nested_to_padded_tensor_meta.h>
|
| 931 |
+
#include <ATen/ops/new_empty_meta.h>
|
| 932 |
+
#include <ATen/ops/new_empty_strided_meta.h>
|
| 933 |
+
#include <ATen/ops/new_full_meta.h>
|
| 934 |
+
#include <ATen/ops/new_ones_meta.h>
|
| 935 |
+
#include <ATen/ops/new_zeros_meta.h>
|
| 936 |
+
#include <ATen/ops/nextafter_meta.h>
|
| 937 |
+
#include <ATen/ops/nll_loss_meta.h>
|
| 938 |
+
#include <ATen/ops/nll_loss2d_meta.h>
|
| 939 |
+
#include <ATen/ops/nll_loss2d_backward_meta.h>
|
| 940 |
+
#include <ATen/ops/nll_loss2d_forward_meta.h>
|
| 941 |
+
#include <ATen/ops/nll_loss_backward_meta.h>
|
| 942 |
+
#include <ATen/ops/nll_loss_forward_meta.h>
|
| 943 |
+
#include <ATen/ops/nll_loss_nd_meta.h>
|
| 944 |
+
#include <ATen/ops/nonzero_meta.h>
|
| 945 |
+
#include <ATen/ops/nonzero_numpy_meta.h>
|
| 946 |
+
#include <ATen/ops/nonzero_static_meta.h>
|
| 947 |
+
#include <ATen/ops/norm_meta.h>
|
| 948 |
+
#include <ATen/ops/norm_except_dim_meta.h>
|
| 949 |
+
#include <ATen/ops/normal_meta.h>
|
| 950 |
+
#include <ATen/ops/not_equal_meta.h>
|
| 951 |
+
#include <ATen/ops/nuclear_norm_meta.h>
|
| 952 |
+
#include <ATen/ops/numpy_T_meta.h>
|
| 953 |
+
#include <ATen/ops/one_hot_meta.h>
|
| 954 |
+
#include <ATen/ops/ones_meta.h>
|
| 955 |
+
#include <ATen/ops/ones_like_meta.h>
|
| 956 |
+
#include <ATen/ops/or_meta.h>
|
| 957 |
+
#include <ATen/ops/orgqr_meta.h>
|
| 958 |
+
#include <ATen/ops/ormqr_meta.h>
|
| 959 |
+
#include <ATen/ops/outer_meta.h>
|
| 960 |
+
#include <ATen/ops/output_nr_meta.h>
|
| 961 |
+
#include <ATen/ops/pad_meta.h>
|
| 962 |
+
#include <ATen/ops/pad_sequence_meta.h>
|
| 963 |
+
#include <ATen/ops/pairwise_distance_meta.h>
|
| 964 |
+
#include <ATen/ops/pdist_meta.h>
|
| 965 |
+
#include <ATen/ops/permute_meta.h>
|
| 966 |
+
#include <ATen/ops/permute_copy_meta.h>
|
| 967 |
+
#include <ATen/ops/pin_memory_meta.h>
|
| 968 |
+
#include <ATen/ops/pinverse_meta.h>
|
| 969 |
+
#include <ATen/ops/pixel_shuffle_meta.h>
|
| 970 |
+
#include <ATen/ops/pixel_unshuffle_meta.h>
|
| 971 |
+
#include <ATen/ops/poisson_meta.h>
|
| 972 |
+
#include <ATen/ops/poisson_nll_loss_meta.h>
|
| 973 |
+
#include <ATen/ops/polar_meta.h>
|
| 974 |
+
#include <ATen/ops/polygamma_meta.h>
|
| 975 |
+
#include <ATen/ops/positive_meta.h>
|
| 976 |
+
#include <ATen/ops/pow_meta.h>
|
| 977 |
+
#include <ATen/ops/prelu_meta.h>
|
| 978 |
+
#include <ATen/ops/prod_meta.h>
|
| 979 |
+
#include <ATen/ops/promote_types_meta.h>
|
| 980 |
+
#include <ATen/ops/put_meta.h>
|
| 981 |
+
#include <ATen/ops/q_per_channel_axis_meta.h>
|
| 982 |
+
#include <ATen/ops/q_per_channel_scales_meta.h>
|
| 983 |
+
#include <ATen/ops/q_per_channel_zero_points_meta.h>
|
| 984 |
+
#include <ATen/ops/q_scale_meta.h>
|
| 985 |
+
#include <ATen/ops/q_zero_point_meta.h>
|
| 986 |
+
#include <ATen/ops/qr_meta.h>
|
| 987 |
+
#include <ATen/ops/qscheme_meta.h>
|
| 988 |
+
#include <ATen/ops/quantile_meta.h>
|
| 989 |
+
#include <ATen/ops/quantize_per_channel_meta.h>
|
| 990 |
+
#include <ATen/ops/quantize_per_tensor_meta.h>
|
| 991 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_meta.h>
|
| 992 |
+
#include <ATen/ops/quantized_batch_norm_meta.h>
|
| 993 |
+
#include <ATen/ops/quantized_gru_cell_meta.h>
|
| 994 |
+
#include <ATen/ops/quantized_lstm_cell_meta.h>
|
| 995 |
+
#include <ATen/ops/quantized_max_pool1d_meta.h>
|
| 996 |
+
#include <ATen/ops/quantized_max_pool2d_meta.h>
|
| 997 |
+
#include <ATen/ops/quantized_max_pool3d_meta.h>
|
| 998 |
+
#include <ATen/ops/quantized_rnn_relu_cell_meta.h>
|
| 999 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_meta.h>
|
| 1000 |
+
#include <ATen/ops/rad2deg_meta.h>
|
| 1001 |
+
#include <ATen/ops/rand_meta.h>
|
| 1002 |
+
#include <ATen/ops/rand_like_meta.h>
|
| 1003 |
+
#include <ATen/ops/randint_meta.h>
|
| 1004 |
+
#include <ATen/ops/randint_like_meta.h>
|
| 1005 |
+
#include <ATen/ops/randn_meta.h>
|
| 1006 |
+
#include <ATen/ops/randn_like_meta.h>
|
| 1007 |
+
#include <ATen/ops/random_meta.h>
|
| 1008 |
+
#include <ATen/ops/randperm_meta.h>
|
| 1009 |
+
#include <ATen/ops/range_meta.h>
|
| 1010 |
+
#include <ATen/ops/ravel_meta.h>
|
| 1011 |
+
#include <ATen/ops/real_meta.h>
|
| 1012 |
+
#include <ATen/ops/reciprocal_meta.h>
|
| 1013 |
+
#include <ATen/ops/record_stream_meta.h>
|
| 1014 |
+
#include <ATen/ops/refine_names_meta.h>
|
| 1015 |
+
#include <ATen/ops/reflection_pad1d_meta.h>
|
| 1016 |
+
#include <ATen/ops/reflection_pad1d_backward_meta.h>
|
| 1017 |
+
#include <ATen/ops/reflection_pad2d_meta.h>
|
| 1018 |
+
#include <ATen/ops/reflection_pad2d_backward_meta.h>
|
| 1019 |
+
#include <ATen/ops/reflection_pad3d_meta.h>
|
| 1020 |
+
#include <ATen/ops/reflection_pad3d_backward_meta.h>
|
| 1021 |
+
#include <ATen/ops/relu_meta.h>
|
| 1022 |
+
#include <ATen/ops/relu6_meta.h>
|
| 1023 |
+
#include <ATen/ops/remainder_meta.h>
|
| 1024 |
+
#include <ATen/ops/rename_meta.h>
|
| 1025 |
+
#include <ATen/ops/renorm_meta.h>
|
| 1026 |
+
#include <ATen/ops/repeat_meta.h>
|
| 1027 |
+
#include <ATen/ops/repeat_interleave_meta.h>
|
| 1028 |
+
#include <ATen/ops/replication_pad1d_meta.h>
|
| 1029 |
+
#include <ATen/ops/replication_pad1d_backward_meta.h>
|
| 1030 |
+
#include <ATen/ops/replication_pad2d_meta.h>
|
| 1031 |
+
#include <ATen/ops/replication_pad2d_backward_meta.h>
|
| 1032 |
+
#include <ATen/ops/replication_pad3d_meta.h>
|
| 1033 |
+
#include <ATen/ops/replication_pad3d_backward_meta.h>
|
| 1034 |
+
#include <ATen/ops/requires_grad_meta.h>
|
| 1035 |
+
#include <ATen/ops/reshape_meta.h>
|
| 1036 |
+
#include <ATen/ops/reshape_as_meta.h>
|
| 1037 |
+
#include <ATen/ops/resize_meta.h>
|
| 1038 |
+
#include <ATen/ops/resize_as_meta.h>
|
| 1039 |
+
#include <ATen/ops/resize_as_sparse_meta.h>
|
| 1040 |
+
#include <ATen/ops/resolve_conj_meta.h>
|
| 1041 |
+
#include <ATen/ops/resolve_neg_meta.h>
|
| 1042 |
+
#include <ATen/ops/result_type_meta.h>
|
| 1043 |
+
#include <ATen/ops/retain_grad_meta.h>
|
| 1044 |
+
#include <ATen/ops/retains_grad_meta.h>
|
| 1045 |
+
#include <ATen/ops/rnn_relu_meta.h>
|
| 1046 |
+
#include <ATen/ops/rnn_relu_cell_meta.h>
|
| 1047 |
+
#include <ATen/ops/rnn_tanh_meta.h>
|
| 1048 |
+
#include <ATen/ops/rnn_tanh_cell_meta.h>
|
| 1049 |
+
#include <ATen/ops/roll_meta.h>
|
| 1050 |
+
#include <ATen/ops/rot90_meta.h>
|
| 1051 |
+
#include <ATen/ops/round_meta.h>
|
| 1052 |
+
#include <ATen/ops/row_indices_meta.h>
|
| 1053 |
+
#include <ATen/ops/row_indices_copy_meta.h>
|
| 1054 |
+
#include <ATen/ops/row_stack_meta.h>
|
| 1055 |
+
#include <ATen/ops/rrelu_meta.h>
|
| 1056 |
+
#include <ATen/ops/rrelu_with_noise_meta.h>
|
| 1057 |
+
#include <ATen/ops/rrelu_with_noise_backward_meta.h>
|
| 1058 |
+
#include <ATen/ops/rshift_meta.h>
|
| 1059 |
+
#include <ATen/ops/rsqrt_meta.h>
|
| 1060 |
+
#include <ATen/ops/rsub_meta.h>
|
| 1061 |
+
#include <ATen/ops/scalar_tensor_meta.h>
|
| 1062 |
+
#include <ATen/ops/scaled_dot_product_attention_meta.h>
|
| 1063 |
+
#include <ATen/ops/scatter_meta.h>
|
| 1064 |
+
#include <ATen/ops/scatter_add_meta.h>
|
| 1065 |
+
#include <ATen/ops/scatter_reduce_meta.h>
|
| 1066 |
+
#include <ATen/ops/searchsorted_meta.h>
|
| 1067 |
+
#include <ATen/ops/segment_reduce_meta.h>
|
| 1068 |
+
#include <ATen/ops/select_meta.h>
|
| 1069 |
+
#include <ATen/ops/select_backward_meta.h>
|
| 1070 |
+
#include <ATen/ops/select_copy_meta.h>
|
| 1071 |
+
#include <ATen/ops/select_scatter_meta.h>
|
| 1072 |
+
#include <ATen/ops/selu_meta.h>
|
| 1073 |
+
#include <ATen/ops/set_meta.h>
|
| 1074 |
+
#include <ATen/ops/set_data_meta.h>
|
| 1075 |
+
#include <ATen/ops/sgn_meta.h>
|
| 1076 |
+
#include <ATen/ops/sigmoid_meta.h>
|
| 1077 |
+
#include <ATen/ops/sigmoid_backward_meta.h>
|
| 1078 |
+
#include <ATen/ops/sign_meta.h>
|
| 1079 |
+
#include <ATen/ops/signbit_meta.h>
|
| 1080 |
+
#include <ATen/ops/silu_meta.h>
|
| 1081 |
+
#include <ATen/ops/silu_backward_meta.h>
|
| 1082 |
+
#include <ATen/ops/sin_meta.h>
|
| 1083 |
+
#include <ATen/ops/sinc_meta.h>
|
| 1084 |
+
#include <ATen/ops/sinh_meta.h>
|
| 1085 |
+
#include <ATen/ops/size_meta.h>
|
| 1086 |
+
#include <ATen/ops/slice_meta.h>
|
| 1087 |
+
#include <ATen/ops/slice_backward_meta.h>
|
| 1088 |
+
#include <ATen/ops/slice_copy_meta.h>
|
| 1089 |
+
#include <ATen/ops/slice_inverse_meta.h>
|
| 1090 |
+
#include <ATen/ops/slice_scatter_meta.h>
|
| 1091 |
+
#include <ATen/ops/slogdet_meta.h>
|
| 1092 |
+
#include <ATen/ops/slow_conv3d_meta.h>
|
| 1093 |
+
#include <ATen/ops/slow_conv3d_forward_meta.h>
|
| 1094 |
+
#include <ATen/ops/slow_conv_dilated2d_meta.h>
|
| 1095 |
+
#include <ATen/ops/slow_conv_dilated3d_meta.h>
|
| 1096 |
+
#include <ATen/ops/slow_conv_transpose2d_meta.h>
|
| 1097 |
+
#include <ATen/ops/slow_conv_transpose3d_meta.h>
|
| 1098 |
+
#include <ATen/ops/smm_meta.h>
|
| 1099 |
+
#include <ATen/ops/smooth_l1_loss_meta.h>
|
| 1100 |
+
#include <ATen/ops/smooth_l1_loss_backward_meta.h>
|
| 1101 |
+
#include <ATen/ops/soft_margin_loss_meta.h>
|
| 1102 |
+
#include <ATen/ops/soft_margin_loss_backward_meta.h>
|
| 1103 |
+
#include <ATen/ops/softmax_meta.h>
|
| 1104 |
+
#include <ATen/ops/softplus_meta.h>
|
| 1105 |
+
#include <ATen/ops/softplus_backward_meta.h>
|
| 1106 |
+
#include <ATen/ops/softshrink_meta.h>
|
| 1107 |
+
#include <ATen/ops/softshrink_backward_meta.h>
|
| 1108 |
+
#include <ATen/ops/sort_meta.h>
|
| 1109 |
+
#include <ATen/ops/sparse_bsc_tensor_meta.h>
|
| 1110 |
+
#include <ATen/ops/sparse_bsr_tensor_meta.h>
|
| 1111 |
+
#include <ATen/ops/sparse_compressed_tensor_meta.h>
|
| 1112 |
+
#include <ATen/ops/sparse_coo_tensor_meta.h>
|
| 1113 |
+
#include <ATen/ops/sparse_csc_tensor_meta.h>
|
| 1114 |
+
#include <ATen/ops/sparse_csr_tensor_meta.h>
|
| 1115 |
+
#include <ATen/ops/sparse_dim_meta.h>
|
| 1116 |
+
#include <ATen/ops/sparse_mask_meta.h>
|
| 1117 |
+
#include <ATen/ops/sparse_resize_meta.h>
|
| 1118 |
+
#include <ATen/ops/sparse_resize_and_clear_meta.h>
|
| 1119 |
+
#include <ATen/ops/sparse_sampled_addmm_meta.h>
|
| 1120 |
+
#include <ATen/ops/special_airy_ai_meta.h>
|
| 1121 |
+
#include <ATen/ops/special_bessel_j0_meta.h>
|
| 1122 |
+
#include <ATen/ops/special_bessel_j1_meta.h>
|
| 1123 |
+
#include <ATen/ops/special_bessel_y0_meta.h>
|
| 1124 |
+
#include <ATen/ops/special_bessel_y1_meta.h>
|
| 1125 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_meta.h>
|
| 1126 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_meta.h>
|
| 1127 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_meta.h>
|
| 1128 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_meta.h>
|
| 1129 |
+
#include <ATen/ops/special_digamma_meta.h>
|
| 1130 |
+
#include <ATen/ops/special_entr_meta.h>
|
| 1131 |
+
#include <ATen/ops/special_erf_meta.h>
|
| 1132 |
+
#include <ATen/ops/special_erfc_meta.h>
|
| 1133 |
+
#include <ATen/ops/special_erfcx_meta.h>
|
| 1134 |
+
#include <ATen/ops/special_erfinv_meta.h>
|
| 1135 |
+
#include <ATen/ops/special_exp2_meta.h>
|
| 1136 |
+
#include <ATen/ops/special_expit_meta.h>
|
| 1137 |
+
#include <ATen/ops/special_expm1_meta.h>
|
| 1138 |
+
#include <ATen/ops/special_gammainc_meta.h>
|
| 1139 |
+
#include <ATen/ops/special_gammaincc_meta.h>
|
| 1140 |
+
#include <ATen/ops/special_gammaln_meta.h>
|
| 1141 |
+
#include <ATen/ops/special_hermite_polynomial_h_meta.h>
|
| 1142 |
+
#include <ATen/ops/special_hermite_polynomial_he_meta.h>
|
| 1143 |
+
#include <ATen/ops/special_i0_meta.h>
|
| 1144 |
+
#include <ATen/ops/special_i0e_meta.h>
|
| 1145 |
+
#include <ATen/ops/special_i1_meta.h>
|
| 1146 |
+
#include <ATen/ops/special_i1e_meta.h>
|
| 1147 |
+
#include <ATen/ops/special_laguerre_polynomial_l_meta.h>
|
| 1148 |
+
#include <ATen/ops/special_legendre_polynomial_p_meta.h>
|
| 1149 |
+
#include <ATen/ops/special_log1p_meta.h>
|
| 1150 |
+
#include <ATen/ops/special_log_ndtr_meta.h>
|
| 1151 |
+
#include <ATen/ops/special_log_softmax_meta.h>
|
| 1152 |
+
#include <ATen/ops/special_logit_meta.h>
|
| 1153 |
+
#include <ATen/ops/special_logsumexp_meta.h>
|
| 1154 |
+
#include <ATen/ops/special_modified_bessel_i0_meta.h>
|
| 1155 |
+
#include <ATen/ops/special_modified_bessel_i1_meta.h>
|
| 1156 |
+
#include <ATen/ops/special_modified_bessel_k0_meta.h>
|
| 1157 |
+
#include <ATen/ops/special_modified_bessel_k1_meta.h>
|
| 1158 |
+
#include <ATen/ops/special_multigammaln_meta.h>
|
| 1159 |
+
#include <ATen/ops/special_ndtr_meta.h>
|
| 1160 |
+
#include <ATen/ops/special_ndtri_meta.h>
|
| 1161 |
+
#include <ATen/ops/special_polygamma_meta.h>
|
| 1162 |
+
#include <ATen/ops/special_psi_meta.h>
|
| 1163 |
+
#include <ATen/ops/special_round_meta.h>
|
| 1164 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_meta.h>
|
| 1165 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_meta.h>
|
| 1166 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta.h>
|
| 1167 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta.h>
|
| 1168 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta.h>
|
| 1169 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta.h>
|
| 1170 |
+
#include <ATen/ops/special_sinc_meta.h>
|
| 1171 |
+
#include <ATen/ops/special_softmax_meta.h>
|
| 1172 |
+
#include <ATen/ops/special_spherical_bessel_j0_meta.h>
|
| 1173 |
+
#include <ATen/ops/special_xlog1py_meta.h>
|
| 1174 |
+
#include <ATen/ops/special_xlogy_meta.h>
|
| 1175 |
+
#include <ATen/ops/special_zeta_meta.h>
|
| 1176 |
+
#include <ATen/ops/split_meta.h>
|
| 1177 |
+
#include <ATen/ops/split_copy_meta.h>
|
| 1178 |
+
#include <ATen/ops/split_with_sizes_meta.h>
|
| 1179 |
+
#include <ATen/ops/split_with_sizes_copy_meta.h>
|
| 1180 |
+
#include <ATen/ops/sqrt_meta.h>
|
| 1181 |
+
#include <ATen/ops/square_meta.h>
|
| 1182 |
+
#include <ATen/ops/squeeze_meta.h>
|
| 1183 |
+
#include <ATen/ops/squeeze_copy_meta.h>
|
| 1184 |
+
#include <ATen/ops/sspaddmm_meta.h>
|
| 1185 |
+
#include <ATen/ops/stack_meta.h>
|
| 1186 |
+
#include <ATen/ops/std_meta.h>
|
| 1187 |
+
#include <ATen/ops/std_mean_meta.h>
|
| 1188 |
+
#include <ATen/ops/stft_meta.h>
|
| 1189 |
+
#include <ATen/ops/stride_meta.h>
|
| 1190 |
+
#include <ATen/ops/sub_meta.h>
|
| 1191 |
+
#include <ATen/ops/subtract_meta.h>
|
| 1192 |
+
#include <ATen/ops/sum_meta.h>
|
| 1193 |
+
#include <ATen/ops/sum_to_size_meta.h>
|
| 1194 |
+
#include <ATen/ops/svd_meta.h>
|
| 1195 |
+
#include <ATen/ops/swapaxes_meta.h>
|
| 1196 |
+
#include <ATen/ops/swapdims_meta.h>
|
| 1197 |
+
#include <ATen/ops/sym_constrain_range_meta.h>
|
| 1198 |
+
#include <ATen/ops/sym_constrain_range_for_size_meta.h>
|
| 1199 |
+
#include <ATen/ops/sym_numel_meta.h>
|
| 1200 |
+
#include <ATen/ops/sym_size_meta.h>
|
| 1201 |
+
#include <ATen/ops/sym_storage_offset_meta.h>
|
| 1202 |
+
#include <ATen/ops/sym_stride_meta.h>
|
| 1203 |
+
#include <ATen/ops/t_meta.h>
|
| 1204 |
+
#include <ATen/ops/t_copy_meta.h>
|
| 1205 |
+
#include <ATen/ops/take_meta.h>
|
| 1206 |
+
#include <ATen/ops/take_along_dim_meta.h>
|
| 1207 |
+
#include <ATen/ops/tan_meta.h>
|
| 1208 |
+
#include <ATen/ops/tanh_meta.h>
|
| 1209 |
+
#include <ATen/ops/tanh_backward_meta.h>
|
| 1210 |
+
#include <ATen/ops/tensor_split_meta.h>
|
| 1211 |
+
#include <ATen/ops/tensordot_meta.h>
|
| 1212 |
+
#include <ATen/ops/thnn_conv2d_meta.h>
|
| 1213 |
+
#include <ATen/ops/threshold_meta.h>
|
| 1214 |
+
#include <ATen/ops/threshold_backward_meta.h>
|
| 1215 |
+
#include <ATen/ops/tile_meta.h>
|
| 1216 |
+
#include <ATen/ops/to_meta.h>
|
| 1217 |
+
#include <ATen/ops/to_dense_meta.h>
|
| 1218 |
+
#include <ATen/ops/to_dense_backward_meta.h>
|
| 1219 |
+
#include <ATen/ops/to_mkldnn_meta.h>
|
| 1220 |
+
#include <ATen/ops/to_mkldnn_backward_meta.h>
|
| 1221 |
+
#include <ATen/ops/to_padded_tensor_meta.h>
|
| 1222 |
+
#include <ATen/ops/to_sparse_meta.h>
|
| 1223 |
+
#include <ATen/ops/to_sparse_bsc_meta.h>
|
| 1224 |
+
#include <ATen/ops/to_sparse_bsr_meta.h>
|
| 1225 |
+
#include <ATen/ops/to_sparse_csc_meta.h>
|
| 1226 |
+
#include <ATen/ops/to_sparse_csr_meta.h>
|
| 1227 |
+
#include <ATen/ops/topk_meta.h>
|
| 1228 |
+
#include <ATen/ops/trace_meta.h>
|
| 1229 |
+
#include <ATen/ops/trace_backward_meta.h>
|
| 1230 |
+
#include <ATen/ops/transpose_meta.h>
|
| 1231 |
+
#include <ATen/ops/transpose_copy_meta.h>
|
| 1232 |
+
#include <ATen/ops/trapezoid_meta.h>
|
| 1233 |
+
#include <ATen/ops/trapz_meta.h>
|
| 1234 |
+
#include <ATen/ops/triangular_solve_meta.h>
|
| 1235 |
+
#include <ATen/ops/tril_meta.h>
|
| 1236 |
+
#include <ATen/ops/tril_indices_meta.h>
|
| 1237 |
+
#include <ATen/ops/triplet_margin_loss_meta.h>
|
| 1238 |
+
#include <ATen/ops/triu_meta.h>
|
| 1239 |
+
#include <ATen/ops/triu_indices_meta.h>
|
| 1240 |
+
#include <ATen/ops/true_divide_meta.h>
|
| 1241 |
+
#include <ATen/ops/trunc_meta.h>
|
| 1242 |
+
#include <ATen/ops/type_as_meta.h>
|
| 1243 |
+
#include <ATen/ops/unbind_meta.h>
|
| 1244 |
+
#include <ATen/ops/unbind_copy_meta.h>
|
| 1245 |
+
#include <ATen/ops/unflatten_meta.h>
|
| 1246 |
+
#include <ATen/ops/unflatten_dense_tensors_meta.h>
|
| 1247 |
+
#include <ATen/ops/unfold_meta.h>
|
| 1248 |
+
#include <ATen/ops/unfold_backward_meta.h>
|
| 1249 |
+
#include <ATen/ops/unfold_copy_meta.h>
|
| 1250 |
+
#include <ATen/ops/uniform_meta.h>
|
| 1251 |
+
#include <ATen/ops/unique_consecutive_meta.h>
|
| 1252 |
+
#include <ATen/ops/unique_dim_meta.h>
|
| 1253 |
+
#include <ATen/ops/unique_dim_consecutive_meta.h>
|
| 1254 |
+
#include <ATen/ops/unsafe_chunk_meta.h>
|
| 1255 |
+
#include <ATen/ops/unsafe_split_meta.h>
|
| 1256 |
+
#include <ATen/ops/unsafe_split_with_sizes_meta.h>
|
| 1257 |
+
#include <ATen/ops/unsqueeze_meta.h>
|
| 1258 |
+
#include <ATen/ops/unsqueeze_copy_meta.h>
|
| 1259 |
+
#include <ATen/ops/upsample_bicubic2d_meta.h>
|
| 1260 |
+
#include <ATen/ops/upsample_bicubic2d_backward_meta.h>
|
| 1261 |
+
#include <ATen/ops/upsample_bilinear2d_meta.h>
|
| 1262 |
+
#include <ATen/ops/upsample_bilinear2d_backward_meta.h>
|
| 1263 |
+
#include <ATen/ops/upsample_linear1d_meta.h>
|
| 1264 |
+
#include <ATen/ops/upsample_linear1d_backward_meta.h>
|
| 1265 |
+
#include <ATen/ops/upsample_nearest1d_meta.h>
|
| 1266 |
+
#include <ATen/ops/upsample_nearest1d_backward_meta.h>
|
| 1267 |
+
#include <ATen/ops/upsample_nearest2d_meta.h>
|
| 1268 |
+
#include <ATen/ops/upsample_nearest2d_backward_meta.h>
|
| 1269 |
+
#include <ATen/ops/upsample_nearest3d_meta.h>
|
| 1270 |
+
#include <ATen/ops/upsample_nearest3d_backward_meta.h>
|
| 1271 |
+
#include <ATen/ops/upsample_trilinear3d_meta.h>
|
| 1272 |
+
#include <ATen/ops/upsample_trilinear3d_backward_meta.h>
|
| 1273 |
+
#include <ATen/ops/value_selecting_reduction_backward_meta.h>
|
| 1274 |
+
#include <ATen/ops/values_meta.h>
|
| 1275 |
+
#include <ATen/ops/values_copy_meta.h>
|
| 1276 |
+
#include <ATen/ops/vander_meta.h>
|
| 1277 |
+
#include <ATen/ops/var_meta.h>
|
| 1278 |
+
#include <ATen/ops/var_mean_meta.h>
|
| 1279 |
+
#include <ATen/ops/vdot_meta.h>
|
| 1280 |
+
#include <ATen/ops/view_meta.h>
|
| 1281 |
+
#include <ATen/ops/view_as_meta.h>
|
| 1282 |
+
#include <ATen/ops/view_as_complex_meta.h>
|
| 1283 |
+
#include <ATen/ops/view_as_complex_copy_meta.h>
|
| 1284 |
+
#include <ATen/ops/view_as_real_meta.h>
|
| 1285 |
+
#include <ATen/ops/view_as_real_copy_meta.h>
|
| 1286 |
+
#include <ATen/ops/view_copy_meta.h>
|
| 1287 |
+
#include <ATen/ops/vsplit_meta.h>
|
| 1288 |
+
#include <ATen/ops/vstack_meta.h>
|
| 1289 |
+
#include <ATen/ops/where_meta.h>
|
| 1290 |
+
#include <ATen/ops/xlogy_meta.h>
|
| 1291 |
+
#include <ATen/ops/xor_meta.h>
|
| 1292 |
+
#include <ATen/ops/zero_meta.h>
|
| 1293 |
+
#include <ATen/ops/zeros_meta.h>
|
| 1294 |
+
#include <ATen/ops/zeros_like_meta.h>
|
| 1295 |
+
|
| 1296 |
+
namespace at {
|
| 1297 |
+
|
| 1298 |
+
namespace meta {
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
+
|
| 1302 |
+
} // namespace meta
|
| 1303 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NumericUtils.h
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef __HIPCC__
|
| 4 |
+
#include <hip/hip_runtime.h>
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
#include <c10/macros/Macros.h>
|
| 8 |
+
#include <c10/util/BFloat16.h>
|
| 9 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 10 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 11 |
+
#include <c10/util/Float8_e5m2.h>
|
| 12 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 13 |
+
#include <c10/util/Half.h>
|
| 14 |
+
#include <c10/util/complex.h>
|
| 15 |
+
|
| 16 |
+
#include <cmath>
|
| 17 |
+
#include <type_traits>
|
| 18 |
+
|
| 19 |
+
namespace at {
|
| 20 |
+
|
| 21 |
+
// std::isnan isn't performant to use on integral types; it will
|
| 22 |
+
// (uselessly) convert to floating point and then do the test.
|
| 23 |
+
// This function is.
|
| 24 |
+
|
| 25 |
+
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
|
| 26 |
+
inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
|
| 27 |
+
return false;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
|
| 31 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 32 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 33 |
+
return ::isnan(val);
|
| 34 |
+
#else
|
| 35 |
+
return std::isnan(val);
|
| 36 |
+
#endif
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
|
| 40 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 41 |
+
return std::isnan(val.real()) || std::isnan(val.imag());
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
|
| 45 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 46 |
+
return at::_isnan(static_cast<float>(val));
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
template <
|
| 50 |
+
typename T,
|
| 51 |
+
std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
|
| 52 |
+
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
| 53 |
+
return at::_isnan(static_cast<float>(val));
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
| 57 |
+
return at::_isnan(static_cast<float>(val));
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
template <
|
| 61 |
+
typename T,
|
| 62 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
|
| 63 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 64 |
+
return val.isnan();
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <
|
| 68 |
+
typename T,
|
| 69 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
|
| 70 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 71 |
+
return val.isnan();
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
template <
|
| 75 |
+
typename T,
|
| 76 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
|
| 77 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 78 |
+
return val.isnan();
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <
|
| 82 |
+
typename T,
|
| 83 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
|
| 84 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 85 |
+
return val.isnan();
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
// std::isinf isn't performant to use on integral types; it will
|
| 89 |
+
// (uselessly) convert to floating point and then do the test.
|
| 90 |
+
// This function is.
|
| 91 |
+
|
| 92 |
+
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
|
| 93 |
+
inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
|
| 94 |
+
return false;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
|
| 98 |
+
inline C10_HOST_DEVICE bool _isinf(T val) {
|
| 99 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 100 |
+
return ::isinf(val);
|
| 101 |
+
#else
|
| 102 |
+
return std::isinf(val);
|
| 103 |
+
#endif
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
inline C10_HOST_DEVICE bool _isinf(at::Half val) {
|
| 107 |
+
return at::_isinf(static_cast<float>(val));
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
|
| 111 |
+
return at::_isinf(static_cast<float>(val));
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
|
| 115 |
+
return val.isinf();
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
|
| 119 |
+
return false;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) {
|
| 123 |
+
return false;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) {
|
| 127 |
+
return false;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
template <typename T>
|
| 131 |
+
C10_HOST_DEVICE inline T exp(T x) {
|
| 132 |
+
static_assert(
|
| 133 |
+
!std::is_same_v<T, double>,
|
| 134 |
+
"this template must be used with float or less precise type");
|
| 135 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 136 |
+
// use __expf fast approximation for peak bandwidth
|
| 137 |
+
return __expf(x);
|
| 138 |
+
#else
|
| 139 |
+
return ::exp(x);
|
| 140 |
+
#endif
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
template <>
|
| 144 |
+
C10_HOST_DEVICE inline double exp<double>(double x) {
|
| 145 |
+
return ::exp(x);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
template <typename T>
|
| 149 |
+
C10_HOST_DEVICE inline T log(T x) {
|
| 150 |
+
static_assert(
|
| 151 |
+
!std::is_same_v<T, double>,
|
| 152 |
+
"this template must be used with float or less precise type");
|
| 153 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 154 |
+
// use __logf fast approximation for peak bandwidth
|
| 155 |
+
return __logf(x);
|
| 156 |
+
#else
|
| 157 |
+
return ::log(x);
|
| 158 |
+
#endif
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
template <>
|
| 162 |
+
C10_HOST_DEVICE inline double log<double>(double x) {
|
| 163 |
+
return ::log(x);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
template <typename T>
|
| 167 |
+
C10_HOST_DEVICE inline T log1p(T x) {
|
| 168 |
+
static_assert(
|
| 169 |
+
!std::is_same_v<T, double>,
|
| 170 |
+
"this template must be used with float or less precise type");
|
| 171 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 172 |
+
// use __logf fast approximation for peak bandwidth
|
| 173 |
+
// NOTE: There is no __log1pf so unfortunately we lose precision.
|
| 174 |
+
return __logf(1.0f + x);
|
| 175 |
+
#else
|
| 176 |
+
return ::log1p(x);
|
| 177 |
+
#endif
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
template <>
|
| 181 |
+
C10_HOST_DEVICE inline double log1p<double>(double x) {
|
| 182 |
+
return ::log1p(x);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
template <typename T>
|
| 186 |
+
C10_HOST_DEVICE inline T tan(T x) {
|
| 187 |
+
static_assert(
|
| 188 |
+
!std::is_same_v<T, double>,
|
| 189 |
+
"this template must be used with float or less precise type");
|
| 190 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 191 |
+
// use __tanf fast approximation for peak bandwidth
|
| 192 |
+
return __tanf(x);
|
| 193 |
+
#else
|
| 194 |
+
return ::tan(x);
|
| 195 |
+
#endif
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template <>
|
| 199 |
+
C10_HOST_DEVICE inline double tan<double>(double x) {
|
| 200 |
+
return ::tan(x);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/OpaqueTensorImpl.h
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/MemoryFormat.h>
|
| 4 |
+
#include <c10/core/SymIntArrayRef.h>
|
| 5 |
+
#include <c10/core/TensorImpl.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// An "Opaque" TensorImpl -- there are no strides and (for now)
|
| 11 |
+
// even data() is not supported (thus no pointer arithmetic).
|
| 12 |
+
|
| 13 |
+
// NOTE: We could allow data() in the future, but would have to ensure pointer
|
| 14 |
+
// arithmetic code is properly guarded.
|
| 15 |
+
//
|
| 16 |
+
// NOTE: This does not support resize_ (and other metadata-changing ops) because
|
| 17 |
+
// of `shallow_copy_and_detach`. We would need to define an interface to
|
| 18 |
+
// "shallow copy" in order to add support.
|
| 19 |
+
|
| 20 |
+
template <typename OpaqueHandle>
|
| 21 |
+
struct TORCH_API OpaqueTensorImpl : public TensorImpl {
|
| 22 |
+
// public constructor for now...
|
| 23 |
+
OpaqueTensorImpl(
|
| 24 |
+
at::DispatchKeySet key_set,
|
| 25 |
+
const caffe2::TypeMeta data_type,
|
| 26 |
+
c10::Device device,
|
| 27 |
+
OpaqueHandle opaque_handle,
|
| 28 |
+
c10::IntArrayRef sizes,
|
| 29 |
+
bool is_non_overlapping_and_dense = true)
|
| 30 |
+
: TensorImpl(key_set, data_type, device),
|
| 31 |
+
opaque_handle_(std::move(opaque_handle)) {
|
| 32 |
+
set_storage_access_should_throw();
|
| 33 |
+
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
| 34 |
+
sizes_and_strides_.set_sizes(sizes);
|
| 35 |
+
refresh_numel();
|
| 36 |
+
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
| 37 |
+
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Destructor doesn't call release_resources because it's
|
| 41 |
+
// unnecessary; don't forget to change that if needed!
|
| 42 |
+
void release_resources() override {
|
| 43 |
+
TensorImpl::release_resources();
|
| 44 |
+
opaque_handle_ = {};
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
void set_size(int64_t dim, int64_t new_size) override {
|
| 48 |
+
AT_ERROR("opaque tensors do not have set_size");
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
void set_stride(int64_t dim, int64_t new_stride) override {
|
| 52 |
+
AT_ERROR("opaque tensors do not have set_stride");
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void set_storage_offset(int64_t storage_offset) override {
|
| 56 |
+
AT_ERROR("opaque tensors do not have set_storage_offset");
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
#ifdef DEBUG
|
| 60 |
+
bool has_storage() const override {
|
| 61 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 62 |
+
!storage_, "OpaqueTensorImpl assumes that storage_ is never set");
|
| 63 |
+
return false;
|
| 64 |
+
}
|
| 65 |
+
#endif
|
| 66 |
+
|
| 67 |
+
/**
|
| 68 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 69 |
+
*
|
| 70 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 71 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 72 |
+
*/
|
| 73 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 74 |
+
const c10::VariableVersion& version_counter,
|
| 75 |
+
bool allow_tensor_metadata_change) const override {
|
| 76 |
+
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
|
| 77 |
+
key_set(),
|
| 78 |
+
dtype(),
|
| 79 |
+
device(),
|
| 80 |
+
opaque_handle_,
|
| 81 |
+
sizes_and_strides_.sizes_arrayref());
|
| 82 |
+
copy_tensor_metadata(
|
| 83 |
+
/*src_opaque_impl=*/this,
|
| 84 |
+
/*dest_opaque_impl=*/impl.get(),
|
| 85 |
+
/*version_counter=*/version_counter,
|
| 86 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 87 |
+
impl->refresh_numel();
|
| 88 |
+
return impl;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/**
|
| 92 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 93 |
+
*
|
| 94 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 95 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 96 |
+
*/
|
| 97 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 98 |
+
c10::VariableVersion&& version_counter,
|
| 99 |
+
bool allow_tensor_metadata_change) const override {
|
| 100 |
+
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
|
| 101 |
+
key_set(),
|
| 102 |
+
dtype(),
|
| 103 |
+
device(),
|
| 104 |
+
opaque_handle_,
|
| 105 |
+
sizes_and_strides_.sizes_arrayref());
|
| 106 |
+
copy_tensor_metadata(
|
| 107 |
+
/*src_opaque_impl=*/this,
|
| 108 |
+
/*dest_opaque_impl=*/impl.get(),
|
| 109 |
+
/*version_counter=*/std::move(version_counter),
|
| 110 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 111 |
+
impl->refresh_numel();
|
| 112 |
+
return impl;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/**
|
| 116 |
+
* Shallow-copies data from another TensorImpl into this TensorImpl.
|
| 117 |
+
*
|
| 118 |
+
* For why this function doesn't check this TensorImpl's
|
| 119 |
+
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
|
| 120 |
+
*/
|
| 121 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
| 122 |
+
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
| 123 |
+
auto opaque_impl =
|
| 124 |
+
static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
|
| 125 |
+
copy_tensor_metadata(
|
| 126 |
+
/*src_impl=*/opaque_impl,
|
| 127 |
+
/*dest_impl=*/this,
|
| 128 |
+
/*version_counter=*/version_counter(),
|
| 129 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
| 130 |
+
refresh_numel();
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
const OpaqueHandle& opaque_handle() const {
|
| 134 |
+
return opaque_handle_;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
OpaqueHandle& unsafe_opaque_handle() {
|
| 138 |
+
return opaque_handle_;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
protected:
|
| 142 |
+
/**
|
| 143 |
+
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
|
| 144 |
+
* storage_offset) from one TensorImpl to another TensorImpl.
|
| 145 |
+
*
|
| 146 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
|
| 147 |
+
* [ TensorImpl Shallow-Copying ].
|
| 148 |
+
*/
|
| 149 |
+
static void copy_tensor_metadata(
|
| 150 |
+
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
|
| 151 |
+
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
|
| 152 |
+
const c10::VariableVersion& version_counter,
|
| 153 |
+
bool allow_tensor_metadata_change) {
|
| 154 |
+
TensorImpl::copy_tensor_metadata(
|
| 155 |
+
src_opaque_impl,
|
| 156 |
+
dest_opaque_impl,
|
| 157 |
+
version_counter,
|
| 158 |
+
allow_tensor_metadata_change);
|
| 159 |
+
|
| 160 |
+
// OpaqueTensorImpl-specific fields.
|
| 161 |
+
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
static void copy_tensor_metadata(
|
| 165 |
+
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
|
| 166 |
+
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
|
| 167 |
+
c10::VariableVersion&& version_counter,
|
| 168 |
+
bool allow_tensor_metadata_change) {
|
| 169 |
+
TensorImpl::copy_tensor_metadata(
|
| 170 |
+
src_opaque_impl,
|
| 171 |
+
dest_opaque_impl,
|
| 172 |
+
std::move(version_counter),
|
| 173 |
+
allow_tensor_metadata_change);
|
| 174 |
+
|
| 175 |
+
// OpaqueTensorImpl-specific fields.
|
| 176 |
+
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
private:
|
| 180 |
+
const char* tensorimpl_type_name() const override {
|
| 181 |
+
return "OpaqueTensorImpl";
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
OpaqueHandle opaque_handle_;
|
| 185 |
+
};
|
| 186 |
+
|
| 187 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Operators.h
ADDED
|
@@ -0,0 +1,1358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operators.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}_ops.h> \
|
| 17 |
+
and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <c10/core/SymInt.h>
|
| 21 |
+
#include <c10/core/SymIntArrayRef.h>
|
| 22 |
+
#include <c10/core/Scalar.h>
|
| 23 |
+
#include <c10/core/TensorOptions.h>
|
| 24 |
+
#include <c10/core/QScheme.h>
|
| 25 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 26 |
+
#include <tuple>
|
| 27 |
+
#include <vector>
|
| 28 |
+
|
| 29 |
+
#include <ATen/ops/_adaptive_avg_pool2d_ops.h>
|
| 30 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_ops.h>
|
| 31 |
+
#include <ATen/ops/_adaptive_avg_pool3d_ops.h>
|
| 32 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_ops.h>
|
| 33 |
+
#include <ATen/ops/_add_batch_dim_ops.h>
|
| 34 |
+
#include <ATen/ops/_add_relu_ops.h>
|
| 35 |
+
#include <ATen/ops/_addmm_activation_ops.h>
|
| 36 |
+
#include <ATen/ops/_aminmax_ops.h>
|
| 37 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h>
|
| 38 |
+
#include <ATen/ops/_amp_update_scale_ops.h>
|
| 39 |
+
#include <ATen/ops/_assert_async_ops.h>
|
| 40 |
+
#include <ATen/ops/_assert_scalar_ops.h>
|
| 41 |
+
#include <ATen/ops/_assert_tensor_metadata_ops.h>
|
| 42 |
+
#include <ATen/ops/_autocast_to_full_precision_ops.h>
|
| 43 |
+
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
|
| 44 |
+
#include <ATen/ops/_backward_ops.h>
|
| 45 |
+
#include <ATen/ops/_batch_norm_impl_index_ops.h>
|
| 46 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_ops.h>
|
| 47 |
+
#include <ATen/ops/_cast_Byte_ops.h>
|
| 48 |
+
#include <ATen/ops/_cast_Char_ops.h>
|
| 49 |
+
#include <ATen/ops/_cast_Double_ops.h>
|
| 50 |
+
#include <ATen/ops/_cast_Float_ops.h>
|
| 51 |
+
#include <ATen/ops/_cast_Half_ops.h>
|
| 52 |
+
#include <ATen/ops/_cast_Int_ops.h>
|
| 53 |
+
#include <ATen/ops/_cast_Long_ops.h>
|
| 54 |
+
#include <ATen/ops/_cast_Short_ops.h>
|
| 55 |
+
#include <ATen/ops/_cdist_backward_ops.h>
|
| 56 |
+
#include <ATen/ops/_cdist_forward_ops.h>
|
| 57 |
+
#include <ATen/ops/_cholesky_solve_helper_ops.h>
|
| 58 |
+
#include <ATen/ops/_choose_qparams_per_tensor_ops.h>
|
| 59 |
+
#include <ATen/ops/_chunk_cat_ops.h>
|
| 60 |
+
#include <ATen/ops/_coalesce_ops.h>
|
| 61 |
+
#include <ATen/ops/_coalesced_ops.h>
|
| 62 |
+
#include <ATen/ops/_compute_linear_combination_ops.h>
|
| 63 |
+
#include <ATen/ops/_conj_ops.h>
|
| 64 |
+
#include <ATen/ops/_conj_copy_ops.h>
|
| 65 |
+
#include <ATen/ops/_conj_physical_ops.h>
|
| 66 |
+
#include <ATen/ops/_conv_depthwise2d_ops.h>
|
| 67 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_ops.h>
|
| 68 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_ops.h>
|
| 69 |
+
#include <ATen/ops/_convert_weight_to_int4pack_ops.h>
|
| 70 |
+
#include <ATen/ops/_convolution_ops.h>
|
| 71 |
+
#include <ATen/ops/_convolution_double_backward_ops.h>
|
| 72 |
+
#include <ATen/ops/_convolution_mode_ops.h>
|
| 73 |
+
#include <ATen/ops/_copy_from_ops.h>
|
| 74 |
+
#include <ATen/ops/_copy_from_and_resize_ops.h>
|
| 75 |
+
#include <ATen/ops/_cslt_compress_ops.h>
|
| 76 |
+
#include <ATen/ops/_cslt_sparse_mm_ops.h>
|
| 77 |
+
#include <ATen/ops/_cslt_sparse_mm_search_ops.h>
|
| 78 |
+
#include <ATen/ops/_ctc_loss_ops.h>
|
| 79 |
+
#include <ATen/ops/_ctc_loss_backward_ops.h>
|
| 80 |
+
#include <ATen/ops/_cudnn_ctc_loss_ops.h>
|
| 81 |
+
#include <ATen/ops/_cudnn_init_dropout_state_ops.h>
|
| 82 |
+
#include <ATen/ops/_cudnn_rnn_ops.h>
|
| 83 |
+
#include <ATen/ops/_cudnn_rnn_backward_ops.h>
|
| 84 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_ops.h>
|
| 85 |
+
#include <ATen/ops/_cufft_clear_plan_cache_ops.h>
|
| 86 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_ops.h>
|
| 87 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_ops.h>
|
| 88 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_ops.h>
|
| 89 |
+
#include <ATen/ops/_cummax_helper_ops.h>
|
| 90 |
+
#include <ATen/ops/_cummin_helper_ops.h>
|
| 91 |
+
#include <ATen/ops/_debug_has_internal_overlap_ops.h>
|
| 92 |
+
#include <ATen/ops/_dimI_ops.h>
|
| 93 |
+
#include <ATen/ops/_dimV_ops.h>
|
| 94 |
+
#include <ATen/ops/_dim_arange_ops.h>
|
| 95 |
+
#include <ATen/ops/_dirichlet_grad_ops.h>
|
| 96 |
+
#include <ATen/ops/_efficient_attention_backward_ops.h>
|
| 97 |
+
#include <ATen/ops/_efficient_attention_forward_ops.h>
|
| 98 |
+
#include <ATen/ops/_efficientzerotensor_ops.h>
|
| 99 |
+
#include <ATen/ops/_embedding_bag_ops.h>
|
| 100 |
+
#include <ATen/ops/_embedding_bag_backward_ops.h>
|
| 101 |
+
#include <ATen/ops/_embedding_bag_dense_backward_ops.h>
|
| 102 |
+
#include <ATen/ops/_embedding_bag_forward_only_ops.h>
|
| 103 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_ops.h>
|
| 104 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_ops.h>
|
| 105 |
+
#include <ATen/ops/_empty_affine_quantized_ops.h>
|
| 106 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
|
| 107 |
+
#include <ATen/ops/_euclidean_dist_ops.h>
|
| 108 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_ops.h>
|
| 109 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_ops.h>
|
| 110 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_ops.h>
|
| 111 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_ops.h>
|
| 112 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_ops.h>
|
| 113 |
+
#include <ATen/ops/_fft_c2c_ops.h>
|
| 114 |
+
#include <ATen/ops/_fft_c2r_ops.h>
|
| 115 |
+
#include <ATen/ops/_fft_r2c_ops.h>
|
| 116 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_ops.h>
|
| 117 |
+
#include <ATen/ops/_flash_attention_backward_ops.h>
|
| 118 |
+
#include <ATen/ops/_flash_attention_forward_ops.h>
|
| 119 |
+
#include <ATen/ops/_foobar_ops.h>
|
| 120 |
+
#include <ATen/ops/_foreach_abs_ops.h>
|
| 121 |
+
#include <ATen/ops/_foreach_acos_ops.h>
|
| 122 |
+
#include <ATen/ops/_foreach_add_ops.h>
|
| 123 |
+
#include <ATen/ops/_foreach_addcdiv_ops.h>
|
| 124 |
+
#include <ATen/ops/_foreach_addcmul_ops.h>
|
| 125 |
+
#include <ATen/ops/_foreach_asin_ops.h>
|
| 126 |
+
#include <ATen/ops/_foreach_atan_ops.h>
|
| 127 |
+
#include <ATen/ops/_foreach_ceil_ops.h>
|
| 128 |
+
#include <ATen/ops/_foreach_clamp_max_ops.h>
|
| 129 |
+
#include <ATen/ops/_foreach_clamp_min_ops.h>
|
| 130 |
+
#include <ATen/ops/_foreach_copy_ops.h>
|
| 131 |
+
#include <ATen/ops/_foreach_cos_ops.h>
|
| 132 |
+
#include <ATen/ops/_foreach_cosh_ops.h>
|
| 133 |
+
#include <ATen/ops/_foreach_div_ops.h>
|
| 134 |
+
#include <ATen/ops/_foreach_erf_ops.h>
|
| 135 |
+
#include <ATen/ops/_foreach_erfc_ops.h>
|
| 136 |
+
#include <ATen/ops/_foreach_exp_ops.h>
|
| 137 |
+
#include <ATen/ops/_foreach_expm1_ops.h>
|
| 138 |
+
#include <ATen/ops/_foreach_floor_ops.h>
|
| 139 |
+
#include <ATen/ops/_foreach_frac_ops.h>
|
| 140 |
+
#include <ATen/ops/_foreach_lerp_ops.h>
|
| 141 |
+
#include <ATen/ops/_foreach_lgamma_ops.h>
|
| 142 |
+
#include <ATen/ops/_foreach_log_ops.h>
|
| 143 |
+
#include <ATen/ops/_foreach_log10_ops.h>
|
| 144 |
+
#include <ATen/ops/_foreach_log1p_ops.h>
|
| 145 |
+
#include <ATen/ops/_foreach_log2_ops.h>
|
| 146 |
+
#include <ATen/ops/_foreach_maximum_ops.h>
|
| 147 |
+
#include <ATen/ops/_foreach_minimum_ops.h>
|
| 148 |
+
#include <ATen/ops/_foreach_mul_ops.h>
|
| 149 |
+
#include <ATen/ops/_foreach_neg_ops.h>
|
| 150 |
+
#include <ATen/ops/_foreach_norm_ops.h>
|
| 151 |
+
#include <ATen/ops/_foreach_pow_ops.h>
|
| 152 |
+
#include <ATen/ops/_foreach_reciprocal_ops.h>
|
| 153 |
+
#include <ATen/ops/_foreach_round_ops.h>
|
| 154 |
+
#include <ATen/ops/_foreach_sigmoid_ops.h>
|
| 155 |
+
#include <ATen/ops/_foreach_sign_ops.h>
|
| 156 |
+
#include <ATen/ops/_foreach_sin_ops.h>
|
| 157 |
+
#include <ATen/ops/_foreach_sinh_ops.h>
|
| 158 |
+
#include <ATen/ops/_foreach_sqrt_ops.h>
|
| 159 |
+
#include <ATen/ops/_foreach_sub_ops.h>
|
| 160 |
+
#include <ATen/ops/_foreach_tan_ops.h>
|
| 161 |
+
#include <ATen/ops/_foreach_tanh_ops.h>
|
| 162 |
+
#include <ATen/ops/_foreach_trunc_ops.h>
|
| 163 |
+
#include <ATen/ops/_foreach_zero_ops.h>
|
| 164 |
+
#include <ATen/ops/_functional_assert_async_ops.h>
|
| 165 |
+
#include <ATen/ops/_functional_assert_scalar_ops.h>
|
| 166 |
+
#include <ATen/ops/_functional_sym_constrain_range_ops.h>
|
| 167 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_ops.h>
|
| 168 |
+
#include <ATen/ops/_fused_adam_ops.h>
|
| 169 |
+
#include <ATen/ops/_fused_adamw_ops.h>
|
| 170 |
+
#include <ATen/ops/_fused_dropout_ops.h>
|
| 171 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_ops.h>
|
| 172 |
+
#include <ATen/ops/_fused_sdp_choice_ops.h>
|
| 173 |
+
#include <ATen/ops/_fused_sgd_ops.h>
|
| 174 |
+
#include <ATen/ops/_fw_primal_ops.h>
|
| 175 |
+
#include <ATen/ops/_fw_primal_copy_ops.h>
|
| 176 |
+
#include <ATen/ops/_gather_sparse_backward_ops.h>
|
| 177 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_ops.h>
|
| 178 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_ops.h>
|
| 179 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_ops.h>
|
| 180 |
+
#include <ATen/ops/_has_same_storage_numel_ops.h>
|
| 181 |
+
#include <ATen/ops/_histogramdd_bin_edges_ops.h>
|
| 182 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_ops.h>
|
| 183 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_ops.h>
|
| 184 |
+
#include <ATen/ops/_index_put_impl_ops.h>
|
| 185 |
+
#include <ATen/ops/_indices_ops.h>
|
| 186 |
+
#include <ATen/ops/_indices_copy_ops.h>
|
| 187 |
+
#include <ATen/ops/_int_mm_ops.h>
|
| 188 |
+
#include <ATen/ops/_is_all_true_ops.h>
|
| 189 |
+
#include <ATen/ops/_is_any_true_ops.h>
|
| 190 |
+
#include <ATen/ops/_is_zerotensor_ops.h>
|
| 191 |
+
#include <ATen/ops/_lazy_clone_ops.h>
|
| 192 |
+
#include <ATen/ops/_linalg_check_errors_ops.h>
|
| 193 |
+
#include <ATen/ops/_linalg_det_ops.h>
|
| 194 |
+
#include <ATen/ops/_linalg_eigh_ops.h>
|
| 195 |
+
#include <ATen/ops/_linalg_eigvals_ops.h>
|
| 196 |
+
#include <ATen/ops/_linalg_slogdet_ops.h>
|
| 197 |
+
#include <ATen/ops/_linalg_solve_ex_ops.h>
|
| 198 |
+
#include <ATen/ops/_linalg_svd_ops.h>
|
| 199 |
+
#include <ATen/ops/_local_scalar_dense_ops.h>
|
| 200 |
+
#include <ATen/ops/_log_softmax_ops.h>
|
| 201 |
+
#include <ATen/ops/_log_softmax_backward_data_ops.h>
|
| 202 |
+
#include <ATen/ops/_logcumsumexp_ops.h>
|
| 203 |
+
#include <ATen/ops/_lstm_mps_ops.h>
|
| 204 |
+
#include <ATen/ops/_lu_with_info_ops.h>
|
| 205 |
+
#include <ATen/ops/_make_dep_token_ops.h>
|
| 206 |
+
#include <ATen/ops/_make_dual_ops.h>
|
| 207 |
+
#include <ATen/ops/_make_dual_copy_ops.h>
|
| 208 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_ops.h>
|
| 209 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_ops.h>
|
| 210 |
+
#include <ATen/ops/_masked_scale_ops.h>
|
| 211 |
+
#include <ATen/ops/_masked_softmax_ops.h>
|
| 212 |
+
#include <ATen/ops/_masked_softmax_backward_ops.h>
|
| 213 |
+
#include <ATen/ops/_mixed_dtypes_linear_ops.h>
|
| 214 |
+
#include <ATen/ops/_mkldnn_reshape_ops.h>
|
| 215 |
+
#include <ATen/ops/_mkldnn_transpose_ops.h>
|
| 216 |
+
#include <ATen/ops/_mps_convolution_ops.h>
|
| 217 |
+
#include <ATen/ops/_mps_convolution_transpose_ops.h>
|
| 218 |
+
#include <ATen/ops/_native_batch_norm_legit_ops.h>
|
| 219 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_ops.h>
|
| 220 |
+
#include <ATen/ops/_native_multi_head_attention_ops.h>
|
| 221 |
+
#include <ATen/ops/_neg_view_ops.h>
|
| 222 |
+
#include <ATen/ops/_neg_view_copy_ops.h>
|
| 223 |
+
#include <ATen/ops/_nested_from_padded_ops.h>
|
| 224 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_ops.h>
|
| 225 |
+
#include <ATen/ops/_nested_get_jagged_dummy_ops.h>
|
| 226 |
+
#include <ATen/ops/_nested_get_lengths_ops.h>
|
| 227 |
+
#include <ATen/ops/_nested_get_offsets_ops.h>
|
| 228 |
+
#include <ATen/ops/_nested_get_ragged_idx_ops.h>
|
| 229 |
+
#include <ATen/ops/_nested_get_values_ops.h>
|
| 230 |
+
#include <ATen/ops/_nested_get_values_copy_ops.h>
|
| 231 |
+
#include <ATen/ops/_nested_select_backward_ops.h>
|
| 232 |
+
#include <ATen/ops/_nested_sum_backward_ops.h>
|
| 233 |
+
#include <ATen/ops/_nested_tensor_from_mask_ops.h>
|
| 234 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_ops.h>
|
| 235 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_ops.h>
|
| 236 |
+
#include <ATen/ops/_nested_tensor_size_ops.h>
|
| 237 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_ops.h>
|
| 238 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
|
| 239 |
+
#include <ATen/ops/_nested_tensor_strides_ops.h>
|
| 240 |
+
#include <ATen/ops/_nested_view_from_buffer_ops.h>
|
| 241 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_ops.h>
|
| 242 |
+
#include <ATen/ops/_nested_view_from_jagged_ops.h>
|
| 243 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_ops.h>
|
| 244 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_ops.h>
|
| 245 |
+
#include <ATen/ops/_nnpack_available_ops.h>
|
| 246 |
+
#include <ATen/ops/_nnpack_spatial_convolution_ops.h>
|
| 247 |
+
#include <ATen/ops/_nnz_ops.h>
|
| 248 |
+
#include <ATen/ops/_pack_padded_sequence_ops.h>
|
| 249 |
+
#include <ATen/ops/_pack_padded_sequence_backward_ops.h>
|
| 250 |
+
#include <ATen/ops/_pad_circular_ops.h>
|
| 251 |
+
#include <ATen/ops/_pad_enum_ops.h>
|
| 252 |
+
#include <ATen/ops/_pad_packed_sequence_ops.h>
|
| 253 |
+
#include <ATen/ops/_pdist_backward_ops.h>
|
| 254 |
+
#include <ATen/ops/_pdist_forward_ops.h>
|
| 255 |
+
#include <ATen/ops/_pin_memory_ops.h>
|
| 256 |
+
#include <ATen/ops/_prelu_kernel_ops.h>
|
| 257 |
+
#include <ATen/ops/_prelu_kernel_backward_ops.h>
|
| 258 |
+
#include <ATen/ops/_print_ops.h>
|
| 259 |
+
#include <ATen/ops/_propagate_xla_data_ops.h>
|
| 260 |
+
#include <ATen/ops/_remove_batch_dim_ops.h>
|
| 261 |
+
#include <ATen/ops/_reshape_alias_ops.h>
|
| 262 |
+
#include <ATen/ops/_reshape_alias_copy_ops.h>
|
| 263 |
+
#include <ATen/ops/_reshape_copy_ops.h>
|
| 264 |
+
#include <ATen/ops/_reshape_from_tensor_ops.h>
|
| 265 |
+
#include <ATen/ops/_resize_output_ops.h>
|
| 266 |
+
#include <ATen/ops/_rowwise_prune_ops.h>
|
| 267 |
+
#include <ATen/ops/_sample_dirichlet_ops.h>
|
| 268 |
+
#include <ATen/ops/_saturate_weight_to_fp16_ops.h>
|
| 269 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_ops.h>
|
| 270 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_ops.h>
|
| 271 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_ops.h>
|
| 272 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h>
|
| 273 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_ops.h>
|
| 274 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_ops.h>
|
| 275 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_ops.h>
|
| 276 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_ops.h>
|
| 277 |
+
#include <ATen/ops/_scaled_mm_ops.h>
|
| 278 |
+
#include <ATen/ops/_segment_reduce_backward_ops.h>
|
| 279 |
+
#include <ATen/ops/_shape_as_tensor_ops.h>
|
| 280 |
+
#include <ATen/ops/_slow_conv2d_backward_ops.h>
|
| 281 |
+
#include <ATen/ops/_slow_conv2d_forward_ops.h>
|
| 282 |
+
#include <ATen/ops/_sobol_engine_draw_ops.h>
|
| 283 |
+
#include <ATen/ops/_sobol_engine_ff_ops.h>
|
| 284 |
+
#include <ATen/ops/_sobol_engine_initialize_state_ops.h>
|
| 285 |
+
#include <ATen/ops/_sobol_engine_scramble_ops.h>
|
| 286 |
+
#include <ATen/ops/_softmax_ops.h>
|
| 287 |
+
#include <ATen/ops/_softmax_backward_data_ops.h>
|
| 288 |
+
#include <ATen/ops/_sparse_addmm_ops.h>
|
| 289 |
+
#include <ATen/ops/_sparse_broadcast_to_ops.h>
|
| 290 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_ops.h>
|
| 291 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_ops.h>
|
| 292 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_ops.h>
|
| 293 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_ops.h>
|
| 294 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_ops.h>
|
| 295 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_ops.h>
|
| 296 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_ops.h>
|
| 297 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_ops.h>
|
| 298 |
+
#include <ATen/ops/_sparse_csr_prod_ops.h>
|
| 299 |
+
#include <ATen/ops/_sparse_csr_sum_ops.h>
|
| 300 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_ops.h>
|
| 301 |
+
#include <ATen/ops/_sparse_log_softmax_ops.h>
|
| 302 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_ops.h>
|
| 303 |
+
#include <ATen/ops/_sparse_mask_projection_ops.h>
|
| 304 |
+
#include <ATen/ops/_sparse_mm_ops.h>
|
| 305 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_ops.h>
|
| 306 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_ops.h>
|
| 307 |
+
#include <ATen/ops/_sparse_semi_structured_linear_ops.h>
|
| 308 |
+
#include <ATen/ops/_sparse_softmax_ops.h>
|
| 309 |
+
#include <ATen/ops/_sparse_softmax_backward_data_ops.h>
|
| 310 |
+
#include <ATen/ops/_sparse_sparse_matmul_ops.h>
|
| 311 |
+
#include <ATen/ops/_sparse_sum_ops.h>
|
| 312 |
+
#include <ATen/ops/_sparse_sum_backward_ops.h>
|
| 313 |
+
#include <ATen/ops/_spdiags_ops.h>
|
| 314 |
+
#include <ATen/ops/_stack_ops.h>
|
| 315 |
+
#include <ATen/ops/_standard_gamma_ops.h>
|
| 316 |
+
#include <ATen/ops/_standard_gamma_grad_ops.h>
|
| 317 |
+
#include <ATen/ops/_test_ambiguous_defaults_ops.h>
|
| 318 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_ops.h>
|
| 319 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_ops.h>
|
| 320 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h>
|
| 321 |
+
#include <ATen/ops/_test_check_tensor_ops.h>
|
| 322 |
+
#include <ATen/ops/_test_functorch_fallback_ops.h>
|
| 323 |
+
#include <ATen/ops/_test_optional_filled_intlist_ops.h>
|
| 324 |
+
#include <ATen/ops/_test_optional_floatlist_ops.h>
|
| 325 |
+
#include <ATen/ops/_test_optional_intlist_ops.h>
|
| 326 |
+
#include <ATen/ops/_test_parallel_materialize_ops.h>
|
| 327 |
+
#include <ATen/ops/_test_serialization_subcmul_ops.h>
|
| 328 |
+
#include <ATen/ops/_test_string_default_ops.h>
|
| 329 |
+
#include <ATen/ops/_test_warn_in_autograd_ops.h>
|
| 330 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_ops.h>
|
| 331 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_ops.h>
|
| 332 |
+
#include <ATen/ops/_thnn_fused_gru_cell_ops.h>
|
| 333 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_ops.h>
|
| 334 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_ops.h>
|
| 335 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_ops.h>
|
| 336 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_ops.h>
|
| 337 |
+
#include <ATen/ops/_to_copy_ops.h>
|
| 338 |
+
#include <ATen/ops/_to_cpu_ops.h>
|
| 339 |
+
#include <ATen/ops/_to_dense_ops.h>
|
| 340 |
+
#include <ATen/ops/_to_sparse_ops.h>
|
| 341 |
+
#include <ATen/ops/_to_sparse_bsc_ops.h>
|
| 342 |
+
#include <ATen/ops/_to_sparse_bsr_ops.h>
|
| 343 |
+
#include <ATen/ops/_to_sparse_csc_ops.h>
|
| 344 |
+
#include <ATen/ops/_to_sparse_csr_ops.h>
|
| 345 |
+
#include <ATen/ops/_to_sparse_semi_structured_ops.h>
|
| 346 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_ops.h>
|
| 347 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_ops.h>
|
| 348 |
+
#include <ATen/ops/_trilinear_ops.h>
|
| 349 |
+
#include <ATen/ops/_triton_multi_head_attention_ops.h>
|
| 350 |
+
#include <ATen/ops/_triton_scaled_dot_attention_ops.h>
|
| 351 |
+
#include <ATen/ops/_unique_ops.h>
|
| 352 |
+
#include <ATen/ops/_unique2_ops.h>
|
| 353 |
+
#include <ATen/ops/_unpack_dual_ops.h>
|
| 354 |
+
#include <ATen/ops/_unsafe_index_ops.h>
|
| 355 |
+
#include <ATen/ops/_unsafe_index_put_ops.h>
|
| 356 |
+
#include <ATen/ops/_unsafe_view_ops.h>
|
| 357 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_ops.h>
|
| 358 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_ops.h>
|
| 359 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_ops.h>
|
| 360 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_ops.h>
|
| 361 |
+
#include <ATen/ops/_upsample_nearest_exact1d_ops.h>
|
| 362 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_ops.h>
|
| 363 |
+
#include <ATen/ops/_upsample_nearest_exact2d_ops.h>
|
| 364 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_ops.h>
|
| 365 |
+
#include <ATen/ops/_upsample_nearest_exact3d_ops.h>
|
| 366 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_ops.h>
|
| 367 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_ops.h>
|
| 368 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_ops.h>
|
| 369 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_ops.h>
|
| 370 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_ops.h>
|
| 371 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_ops.h>
|
| 372 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_ops.h>
|
| 373 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_ops.h>
|
| 374 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_ops.h>
|
| 375 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_ops.h>
|
| 376 |
+
#include <ATen/ops/_values_ops.h>
|
| 377 |
+
#include <ATen/ops/_values_copy_ops.h>
|
| 378 |
+
#include <ATen/ops/_version_ops.h>
|
| 379 |
+
#include <ATen/ops/_weight_int4pack_mm_ops.h>
|
| 380 |
+
#include <ATen/ops/_weight_int8pack_mm_ops.h>
|
| 381 |
+
#include <ATen/ops/_weight_norm_ops.h>
|
| 382 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_ops.h>
|
| 383 |
+
#include <ATen/ops/_weight_norm_interface_ops.h>
|
| 384 |
+
#include <ATen/ops/_weight_norm_interface_backward_ops.h>
|
| 385 |
+
#include <ATen/ops/abs_ops.h>
|
| 386 |
+
#include <ATen/ops/absolute_ops.h>
|
| 387 |
+
#include <ATen/ops/acos_ops.h>
|
| 388 |
+
#include <ATen/ops/acosh_ops.h>
|
| 389 |
+
#include <ATen/ops/adaptive_avg_pool1d_ops.h>
|
| 390 |
+
#include <ATen/ops/adaptive_avg_pool2d_ops.h>
|
| 391 |
+
#include <ATen/ops/adaptive_avg_pool3d_ops.h>
|
| 392 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_ops.h>
|
| 393 |
+
#include <ATen/ops/adaptive_max_pool1d_ops.h>
|
| 394 |
+
#include <ATen/ops/adaptive_max_pool2d_ops.h>
|
| 395 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_ops.h>
|
| 396 |
+
#include <ATen/ops/adaptive_max_pool3d_ops.h>
|
| 397 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_ops.h>
|
| 398 |
+
#include <ATen/ops/add_ops.h>
|
| 399 |
+
#include <ATen/ops/addbmm_ops.h>
|
| 400 |
+
#include <ATen/ops/addcdiv_ops.h>
|
| 401 |
+
#include <ATen/ops/addcmul_ops.h>
|
| 402 |
+
#include <ATen/ops/addmm_ops.h>
|
| 403 |
+
#include <ATen/ops/addmv_ops.h>
|
| 404 |
+
#include <ATen/ops/addr_ops.h>
|
| 405 |
+
#include <ATen/ops/adjoint_ops.h>
|
| 406 |
+
#include <ATen/ops/affine_grid_generator_ops.h>
|
| 407 |
+
#include <ATen/ops/affine_grid_generator_backward_ops.h>
|
| 408 |
+
#include <ATen/ops/alias_ops.h>
|
| 409 |
+
#include <ATen/ops/alias_copy_ops.h>
|
| 410 |
+
#include <ATen/ops/align_as_ops.h>
|
| 411 |
+
#include <ATen/ops/align_tensors_ops.h>
|
| 412 |
+
#include <ATen/ops/align_to_ops.h>
|
| 413 |
+
#include <ATen/ops/all_ops.h>
|
| 414 |
+
#include <ATen/ops/allclose_ops.h>
|
| 415 |
+
#include <ATen/ops/alpha_dropout_ops.h>
|
| 416 |
+
#include <ATen/ops/amax_ops.h>
|
| 417 |
+
#include <ATen/ops/amin_ops.h>
|
| 418 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 419 |
+
#include <ATen/ops/and_ops.h>
|
| 420 |
+
#include <ATen/ops/angle_ops.h>
|
| 421 |
+
#include <ATen/ops/any_ops.h>
|
| 422 |
+
#include <ATen/ops/arange_ops.h>
|
| 423 |
+
#include <ATen/ops/arccos_ops.h>
|
| 424 |
+
#include <ATen/ops/arccosh_ops.h>
|
| 425 |
+
#include <ATen/ops/arcsin_ops.h>
|
| 426 |
+
#include <ATen/ops/arcsinh_ops.h>
|
| 427 |
+
#include <ATen/ops/arctan_ops.h>
|
| 428 |
+
#include <ATen/ops/arctan2_ops.h>
|
| 429 |
+
#include <ATen/ops/arctanh_ops.h>
|
| 430 |
+
#include <ATen/ops/argmax_ops.h>
|
| 431 |
+
#include <ATen/ops/argmin_ops.h>
|
| 432 |
+
#include <ATen/ops/argsort_ops.h>
|
| 433 |
+
#include <ATen/ops/argwhere_ops.h>
|
| 434 |
+
#include <ATen/ops/as_strided_ops.h>
|
| 435 |
+
#include <ATen/ops/as_strided_copy_ops.h>
|
| 436 |
+
#include <ATen/ops/as_strided_scatter_ops.h>
|
| 437 |
+
#include <ATen/ops/asin_ops.h>
|
| 438 |
+
#include <ATen/ops/asinh_ops.h>
|
| 439 |
+
#include <ATen/ops/atan_ops.h>
|
| 440 |
+
#include <ATen/ops/atan2_ops.h>
|
| 441 |
+
#include <ATen/ops/atanh_ops.h>
|
| 442 |
+
#include <ATen/ops/atleast_1d_ops.h>
|
| 443 |
+
#include <ATen/ops/atleast_2d_ops.h>
|
| 444 |
+
#include <ATen/ops/atleast_3d_ops.h>
|
| 445 |
+
#include <ATen/ops/avg_pool1d_ops.h>
|
| 446 |
+
#include <ATen/ops/avg_pool2d_ops.h>
|
| 447 |
+
#include <ATen/ops/avg_pool2d_backward_ops.h>
|
| 448 |
+
#include <ATen/ops/avg_pool3d_ops.h>
|
| 449 |
+
#include <ATen/ops/avg_pool3d_backward_ops.h>
|
| 450 |
+
#include <ATen/ops/baddbmm_ops.h>
|
| 451 |
+
#include <ATen/ops/bartlett_window_ops.h>
|
| 452 |
+
#include <ATen/ops/batch_norm_ops.h>
|
| 453 |
+
#include <ATen/ops/batch_norm_backward_elemt_ops.h>
|
| 454 |
+
#include <ATen/ops/batch_norm_backward_reduce_ops.h>
|
| 455 |
+
#include <ATen/ops/batch_norm_elemt_ops.h>
|
| 456 |
+
#include <ATen/ops/batch_norm_gather_stats_ops.h>
|
| 457 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_ops.h>
|
| 458 |
+
#include <ATen/ops/batch_norm_stats_ops.h>
|
| 459 |
+
#include <ATen/ops/batch_norm_update_stats_ops.h>
|
| 460 |
+
#include <ATen/ops/bernoulli_ops.h>
|
| 461 |
+
#include <ATen/ops/bilinear_ops.h>
|
| 462 |
+
#include <ATen/ops/binary_cross_entropy_ops.h>
|
| 463 |
+
#include <ATen/ops/binary_cross_entropy_backward_ops.h>
|
| 464 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_ops.h>
|
| 465 |
+
#include <ATen/ops/bincount_ops.h>
|
| 466 |
+
#include <ATen/ops/binomial_ops.h>
|
| 467 |
+
#include <ATen/ops/bitwise_and_ops.h>
|
| 468 |
+
#include <ATen/ops/bitwise_left_shift_ops.h>
|
| 469 |
+
#include <ATen/ops/bitwise_not_ops.h>
|
| 470 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 471 |
+
#include <ATen/ops/bitwise_right_shift_ops.h>
|
| 472 |
+
#include <ATen/ops/bitwise_xor_ops.h>
|
| 473 |
+
#include <ATen/ops/blackman_window_ops.h>
|
| 474 |
+
#include <ATen/ops/block_diag_ops.h>
|
| 475 |
+
#include <ATen/ops/bmm_ops.h>
|
| 476 |
+
#include <ATen/ops/broadcast_tensors_ops.h>
|
| 477 |
+
#include <ATen/ops/broadcast_to_ops.h>
|
| 478 |
+
#include <ATen/ops/bucketize_ops.h>
|
| 479 |
+
#include <ATen/ops/can_cast_ops.h>
|
| 480 |
+
#include <ATen/ops/cartesian_prod_ops.h>
|
| 481 |
+
#include <ATen/ops/cat_ops.h>
|
| 482 |
+
#include <ATen/ops/cauchy_ops.h>
|
| 483 |
+
#include <ATen/ops/ccol_indices_ops.h>
|
| 484 |
+
#include <ATen/ops/ccol_indices_copy_ops.h>
|
| 485 |
+
#include <ATen/ops/cdist_ops.h>
|
| 486 |
+
#include <ATen/ops/ceil_ops.h>
|
| 487 |
+
#include <ATen/ops/celu_ops.h>
|
| 488 |
+
#include <ATen/ops/chain_matmul_ops.h>
|
| 489 |
+
#include <ATen/ops/chalf_ops.h>
|
| 490 |
+
#include <ATen/ops/channel_shuffle_ops.h>
|
| 491 |
+
#include <ATen/ops/cholesky_ops.h>
|
| 492 |
+
#include <ATen/ops/cholesky_inverse_ops.h>
|
| 493 |
+
#include <ATen/ops/cholesky_solve_ops.h>
|
| 494 |
+
#include <ATen/ops/choose_qparams_optimized_ops.h>
|
| 495 |
+
#include <ATen/ops/chunk_ops.h>
|
| 496 |
+
#include <ATen/ops/clamp_ops.h>
|
| 497 |
+
#include <ATen/ops/clamp_max_ops.h>
|
| 498 |
+
#include <ATen/ops/clamp_min_ops.h>
|
| 499 |
+
#include <ATen/ops/clip_ops.h>
|
| 500 |
+
#include <ATen/ops/clone_ops.h>
|
| 501 |
+
#include <ATen/ops/coalesce_ops.h>
|
| 502 |
+
#include <ATen/ops/col2im_ops.h>
|
| 503 |
+
#include <ATen/ops/col_indices_ops.h>
|
| 504 |
+
#include <ATen/ops/col_indices_copy_ops.h>
|
| 505 |
+
#include <ATen/ops/column_stack_ops.h>
|
| 506 |
+
#include <ATen/ops/combinations_ops.h>
|
| 507 |
+
#include <ATen/ops/complex_ops.h>
|
| 508 |
+
#include <ATen/ops/concat_ops.h>
|
| 509 |
+
#include <ATen/ops/concatenate_ops.h>
|
| 510 |
+
#include <ATen/ops/conj_ops.h>
|
| 511 |
+
#include <ATen/ops/conj_physical_ops.h>
|
| 512 |
+
#include <ATen/ops/constant_pad_nd_ops.h>
|
| 513 |
+
#include <ATen/ops/contiguous_ops.h>
|
| 514 |
+
#include <ATen/ops/conv1d_ops.h>
|
| 515 |
+
#include <ATen/ops/conv2d_ops.h>
|
| 516 |
+
#include <ATen/ops/conv3d_ops.h>
|
| 517 |
+
#include <ATen/ops/conv_depthwise3d_ops.h>
|
| 518 |
+
#include <ATen/ops/conv_tbc_ops.h>
|
| 519 |
+
#include <ATen/ops/conv_tbc_backward_ops.h>
|
| 520 |
+
#include <ATen/ops/conv_transpose1d_ops.h>
|
| 521 |
+
#include <ATen/ops/conv_transpose2d_ops.h>
|
| 522 |
+
#include <ATen/ops/conv_transpose3d_ops.h>
|
| 523 |
+
#include <ATen/ops/convolution_ops.h>
|
| 524 |
+
#include <ATen/ops/convolution_backward_ops.h>
|
| 525 |
+
#include <ATen/ops/convolution_backward_overrideable_ops.h>
|
| 526 |
+
#include <ATen/ops/convolution_overrideable_ops.h>
|
| 527 |
+
#include <ATen/ops/copy_ops.h>
|
| 528 |
+
#include <ATen/ops/copy_sparse_to_sparse_ops.h>
|
| 529 |
+
#include <ATen/ops/copysign_ops.h>
|
| 530 |
+
#include <ATen/ops/corrcoef_ops.h>
|
| 531 |
+
#include <ATen/ops/cos_ops.h>
|
| 532 |
+
#include <ATen/ops/cosh_ops.h>
|
| 533 |
+
#include <ATen/ops/cosine_embedding_loss_ops.h>
|
| 534 |
+
#include <ATen/ops/cosine_similarity_ops.h>
|
| 535 |
+
#include <ATen/ops/count_nonzero_ops.h>
|
| 536 |
+
#include <ATen/ops/cov_ops.h>
|
| 537 |
+
#include <ATen/ops/cross_ops.h>
|
| 538 |
+
#include <ATen/ops/cross_entropy_loss_ops.h>
|
| 539 |
+
#include <ATen/ops/crow_indices_ops.h>
|
| 540 |
+
#include <ATen/ops/crow_indices_copy_ops.h>
|
| 541 |
+
#include <ATen/ops/ctc_loss_ops.h>
|
| 542 |
+
#include <ATen/ops/cudnn_affine_grid_generator_ops.h>
|
| 543 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_ops.h>
|
| 544 |
+
#include <ATen/ops/cudnn_batch_norm_ops.h>
|
| 545 |
+
#include <ATen/ops/cudnn_batch_norm_backward_ops.h>
|
| 546 |
+
#include <ATen/ops/cudnn_convolution_ops.h>
|
| 547 |
+
#include <ATen/ops/cudnn_convolution_add_relu_ops.h>
|
| 548 |
+
#include <ATen/ops/cudnn_convolution_relu_ops.h>
|
| 549 |
+
#include <ATen/ops/cudnn_convolution_transpose_ops.h>
|
| 550 |
+
#include <ATen/ops/cudnn_grid_sampler_ops.h>
|
| 551 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_ops.h>
|
| 552 |
+
#include <ATen/ops/cudnn_is_acceptable_ops.h>
|
| 553 |
+
#include <ATen/ops/cummax_ops.h>
|
| 554 |
+
#include <ATen/ops/cummaxmin_backward_ops.h>
|
| 555 |
+
#include <ATen/ops/cummin_ops.h>
|
| 556 |
+
#include <ATen/ops/cumprod_ops.h>
|
| 557 |
+
#include <ATen/ops/cumprod_backward_ops.h>
|
| 558 |
+
#include <ATen/ops/cumsum_ops.h>
|
| 559 |
+
#include <ATen/ops/cumulative_trapezoid_ops.h>
|
| 560 |
+
#include <ATen/ops/data_ops.h>
|
| 561 |
+
#include <ATen/ops/deg2rad_ops.h>
|
| 562 |
+
#include <ATen/ops/dense_dim_ops.h>
|
| 563 |
+
#include <ATen/ops/dequantize_ops.h>
|
| 564 |
+
#include <ATen/ops/det_ops.h>
|
| 565 |
+
#include <ATen/ops/detach_ops.h>
|
| 566 |
+
#include <ATen/ops/detach_copy_ops.h>
|
| 567 |
+
#include <ATen/ops/diag_ops.h>
|
| 568 |
+
#include <ATen/ops/diag_embed_ops.h>
|
| 569 |
+
#include <ATen/ops/diagflat_ops.h>
|
| 570 |
+
#include <ATen/ops/diagonal_ops.h>
|
| 571 |
+
#include <ATen/ops/diagonal_backward_ops.h>
|
| 572 |
+
#include <ATen/ops/diagonal_copy_ops.h>
|
| 573 |
+
#include <ATen/ops/diagonal_scatter_ops.h>
|
| 574 |
+
#include <ATen/ops/diff_ops.h>
|
| 575 |
+
#include <ATen/ops/digamma_ops.h>
|
| 576 |
+
#include <ATen/ops/dist_ops.h>
|
| 577 |
+
#include <ATen/ops/div_ops.h>
|
| 578 |
+
#include <ATen/ops/divide_ops.h>
|
| 579 |
+
#include <ATen/ops/dot_ops.h>
|
| 580 |
+
#include <ATen/ops/dropout_ops.h>
|
| 581 |
+
#include <ATen/ops/dsplit_ops.h>
|
| 582 |
+
#include <ATen/ops/dstack_ops.h>
|
| 583 |
+
#include <ATen/ops/einsum_ops.h>
|
| 584 |
+
#include <ATen/ops/elu_ops.h>
|
| 585 |
+
#include <ATen/ops/elu_backward_ops.h>
|
| 586 |
+
#include <ATen/ops/embedding_ops.h>
|
| 587 |
+
#include <ATen/ops/embedding_backward_ops.h>
|
| 588 |
+
#include <ATen/ops/embedding_bag_ops.h>
|
| 589 |
+
#include <ATen/ops/embedding_dense_backward_ops.h>
|
| 590 |
+
#include <ATen/ops/embedding_renorm_ops.h>
|
| 591 |
+
#include <ATen/ops/embedding_sparse_backward_ops.h>
|
| 592 |
+
#include <ATen/ops/empty_ops.h>
|
| 593 |
+
#include <ATen/ops/empty_like_ops.h>
|
| 594 |
+
#include <ATen/ops/empty_permuted_ops.h>
|
| 595 |
+
#include <ATen/ops/empty_quantized_ops.h>
|
| 596 |
+
#include <ATen/ops/empty_strided_ops.h>
|
| 597 |
+
#include <ATen/ops/eq_ops.h>
|
| 598 |
+
#include <ATen/ops/equal_ops.h>
|
| 599 |
+
#include <ATen/ops/erf_ops.h>
|
| 600 |
+
#include <ATen/ops/erfc_ops.h>
|
| 601 |
+
#include <ATen/ops/erfinv_ops.h>
|
| 602 |
+
#include <ATen/ops/exp_ops.h>
|
| 603 |
+
#include <ATen/ops/exp2_ops.h>
|
| 604 |
+
#include <ATen/ops/expand_ops.h>
|
| 605 |
+
#include <ATen/ops/expand_as_ops.h>
|
| 606 |
+
#include <ATen/ops/expand_copy_ops.h>
|
| 607 |
+
#include <ATen/ops/expm1_ops.h>
|
| 608 |
+
#include <ATen/ops/exponential_ops.h>
|
| 609 |
+
#include <ATen/ops/eye_ops.h>
|
| 610 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_ops.h>
|
| 611 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_ops.h>
|
| 612 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_ops.h>
|
| 613 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_ops.h>
|
| 614 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_ops.h>
|
| 615 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_ops.h>
|
| 616 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_ops.h>
|
| 617 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_ops.h>
|
| 618 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_ops.h>
|
| 619 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_ops.h>
|
| 620 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_ops.h>
|
| 621 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_ops.h>
|
| 622 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_ops.h>
|
| 623 |
+
#include <ATen/ops/feature_alpha_dropout_ops.h>
|
| 624 |
+
#include <ATen/ops/feature_dropout_ops.h>
|
| 625 |
+
#include <ATen/ops/fft_fft_ops.h>
|
| 626 |
+
#include <ATen/ops/fft_fft2_ops.h>
|
| 627 |
+
#include <ATen/ops/fft_fftfreq_ops.h>
|
| 628 |
+
#include <ATen/ops/fft_fftn_ops.h>
|
| 629 |
+
#include <ATen/ops/fft_fftshift_ops.h>
|
| 630 |
+
#include <ATen/ops/fft_hfft_ops.h>
|
| 631 |
+
#include <ATen/ops/fft_hfft2_ops.h>
|
| 632 |
+
#include <ATen/ops/fft_hfftn_ops.h>
|
| 633 |
+
#include <ATen/ops/fft_ifft_ops.h>
|
| 634 |
+
#include <ATen/ops/fft_ifft2_ops.h>
|
| 635 |
+
#include <ATen/ops/fft_ifftn_ops.h>
|
| 636 |
+
#include <ATen/ops/fft_ifftshift_ops.h>
|
| 637 |
+
#include <ATen/ops/fft_ihfft_ops.h>
|
| 638 |
+
#include <ATen/ops/fft_ihfft2_ops.h>
|
| 639 |
+
#include <ATen/ops/fft_ihfftn_ops.h>
|
| 640 |
+
#include <ATen/ops/fft_irfft_ops.h>
|
| 641 |
+
#include <ATen/ops/fft_irfft2_ops.h>
|
| 642 |
+
#include <ATen/ops/fft_irfftn_ops.h>
|
| 643 |
+
#include <ATen/ops/fft_rfft_ops.h>
|
| 644 |
+
#include <ATen/ops/fft_rfft2_ops.h>
|
| 645 |
+
#include <ATen/ops/fft_rfftfreq_ops.h>
|
| 646 |
+
#include <ATen/ops/fft_rfftn_ops.h>
|
| 647 |
+
#include <ATen/ops/fill_ops.h>
|
| 648 |
+
#include <ATen/ops/fill_diagonal_ops.h>
|
| 649 |
+
#include <ATen/ops/fix_ops.h>
|
| 650 |
+
#include <ATen/ops/flatten_ops.h>
|
| 651 |
+
#include <ATen/ops/flatten_dense_tensors_ops.h>
|
| 652 |
+
#include <ATen/ops/flip_ops.h>
|
| 653 |
+
#include <ATen/ops/fliplr_ops.h>
|
| 654 |
+
#include <ATen/ops/flipud_ops.h>
|
| 655 |
+
#include <ATen/ops/float_power_ops.h>
|
| 656 |
+
#include <ATen/ops/floor_ops.h>
|
| 657 |
+
#include <ATen/ops/floor_divide_ops.h>
|
| 658 |
+
#include <ATen/ops/fmax_ops.h>
|
| 659 |
+
#include <ATen/ops/fmin_ops.h>
|
| 660 |
+
#include <ATen/ops/fmod_ops.h>
|
| 661 |
+
#include <ATen/ops/frac_ops.h>
|
| 662 |
+
#include <ATen/ops/fractional_max_pool2d_ops.h>
|
| 663 |
+
#include <ATen/ops/fractional_max_pool2d_backward_ops.h>
|
| 664 |
+
#include <ATen/ops/fractional_max_pool3d_ops.h>
|
| 665 |
+
#include <ATen/ops/fractional_max_pool3d_backward_ops.h>
|
| 666 |
+
#include <ATen/ops/frexp_ops.h>
|
| 667 |
+
#include <ATen/ops/frobenius_norm_ops.h>
|
| 668 |
+
#include <ATen/ops/from_file_ops.h>
|
| 669 |
+
#include <ATen/ops/full_ops.h>
|
| 670 |
+
#include <ATen/ops/full_like_ops.h>
|
| 671 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_ops.h>
|
| 672 |
+
#include <ATen/ops/gather_ops.h>
|
| 673 |
+
#include <ATen/ops/gather_backward_ops.h>
|
| 674 |
+
#include <ATen/ops/gcd_ops.h>
|
| 675 |
+
#include <ATen/ops/ge_ops.h>
|
| 676 |
+
#include <ATen/ops/gelu_ops.h>
|
| 677 |
+
#include <ATen/ops/gelu_backward_ops.h>
|
| 678 |
+
#include <ATen/ops/geometric_ops.h>
|
| 679 |
+
#include <ATen/ops/geqrf_ops.h>
|
| 680 |
+
#include <ATen/ops/ger_ops.h>
|
| 681 |
+
#include <ATen/ops/glu_ops.h>
|
| 682 |
+
#include <ATen/ops/glu_backward_ops.h>
|
| 683 |
+
#include <ATen/ops/glu_backward_jvp_ops.h>
|
| 684 |
+
#include <ATen/ops/glu_jvp_ops.h>
|
| 685 |
+
#include <ATen/ops/gradient_ops.h>
|
| 686 |
+
#include <ATen/ops/greater_ops.h>
|
| 687 |
+
#include <ATen/ops/greater_equal_ops.h>
|
| 688 |
+
#include <ATen/ops/grid_sampler_ops.h>
|
| 689 |
+
#include <ATen/ops/grid_sampler_2d_ops.h>
|
| 690 |
+
#include <ATen/ops/grid_sampler_2d_backward_ops.h>
|
| 691 |
+
#include <ATen/ops/grid_sampler_3d_ops.h>
|
| 692 |
+
#include <ATen/ops/grid_sampler_3d_backward_ops.h>
|
| 693 |
+
#include <ATen/ops/group_norm_ops.h>
|
| 694 |
+
#include <ATen/ops/gru_ops.h>
|
| 695 |
+
#include <ATen/ops/gru_cell_ops.h>
|
| 696 |
+
#include <ATen/ops/gt_ops.h>
|
| 697 |
+
#include <ATen/ops/hamming_window_ops.h>
|
| 698 |
+
#include <ATen/ops/hann_window_ops.h>
|
| 699 |
+
#include <ATen/ops/hardshrink_ops.h>
|
| 700 |
+
#include <ATen/ops/hardshrink_backward_ops.h>
|
| 701 |
+
#include <ATen/ops/hardsigmoid_ops.h>
|
| 702 |
+
#include <ATen/ops/hardsigmoid_backward_ops.h>
|
| 703 |
+
#include <ATen/ops/hardswish_ops.h>
|
| 704 |
+
#include <ATen/ops/hardswish_backward_ops.h>
|
| 705 |
+
#include <ATen/ops/hardtanh_ops.h>
|
| 706 |
+
#include <ATen/ops/hardtanh_backward_ops.h>
|
| 707 |
+
#include <ATen/ops/heaviside_ops.h>
|
| 708 |
+
#include <ATen/ops/hinge_embedding_loss_ops.h>
|
| 709 |
+
#include <ATen/ops/histc_ops.h>
|
| 710 |
+
#include <ATen/ops/histogram_ops.h>
|
| 711 |
+
#include <ATen/ops/histogramdd_ops.h>
|
| 712 |
+
#include <ATen/ops/hsplit_ops.h>
|
| 713 |
+
#include <ATen/ops/hspmm_ops.h>
|
| 714 |
+
#include <ATen/ops/hstack_ops.h>
|
| 715 |
+
#include <ATen/ops/huber_loss_ops.h>
|
| 716 |
+
#include <ATen/ops/huber_loss_backward_ops.h>
|
| 717 |
+
#include <ATen/ops/hypot_ops.h>
|
| 718 |
+
#include <ATen/ops/i0_ops.h>
|
| 719 |
+
#include <ATen/ops/igamma_ops.h>
|
| 720 |
+
#include <ATen/ops/igammac_ops.h>
|
| 721 |
+
#include <ATen/ops/im2col_ops.h>
|
| 722 |
+
#include <ATen/ops/imag_ops.h>
|
| 723 |
+
#include <ATen/ops/index_ops.h>
|
| 724 |
+
#include <ATen/ops/index_add_ops.h>
|
| 725 |
+
#include <ATen/ops/index_copy_ops.h>
|
| 726 |
+
#include <ATen/ops/index_fill_ops.h>
|
| 727 |
+
#include <ATen/ops/index_put_ops.h>
|
| 728 |
+
#include <ATen/ops/index_reduce_ops.h>
|
| 729 |
+
#include <ATen/ops/index_select_ops.h>
|
| 730 |
+
#include <ATen/ops/index_select_backward_ops.h>
|
| 731 |
+
#include <ATen/ops/indices_ops.h>
|
| 732 |
+
#include <ATen/ops/indices_copy_ops.h>
|
| 733 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_ops.h>
|
| 734 |
+
#include <ATen/ops/inner_ops.h>
|
| 735 |
+
#include <ATen/ops/instance_norm_ops.h>
|
| 736 |
+
#include <ATen/ops/int_repr_ops.h>
|
| 737 |
+
#include <ATen/ops/inverse_ops.h>
|
| 738 |
+
#include <ATen/ops/is_coalesced_ops.h>
|
| 739 |
+
#include <ATen/ops/is_complex_ops.h>
|
| 740 |
+
#include <ATen/ops/is_conj_ops.h>
|
| 741 |
+
#include <ATen/ops/is_distributed_ops.h>
|
| 742 |
+
#include <ATen/ops/is_floating_point_ops.h>
|
| 743 |
+
#include <ATen/ops/is_inference_ops.h>
|
| 744 |
+
#include <ATen/ops/is_leaf_ops.h>
|
| 745 |
+
#include <ATen/ops/is_neg_ops.h>
|
| 746 |
+
#include <ATen/ops/is_nonzero_ops.h>
|
| 747 |
+
#include <ATen/ops/is_pinned_ops.h>
|
| 748 |
+
#include <ATen/ops/is_same_size_ops.h>
|
| 749 |
+
#include <ATen/ops/is_set_to_ops.h>
|
| 750 |
+
#include <ATen/ops/is_signed_ops.h>
|
| 751 |
+
#include <ATen/ops/is_vulkan_available_ops.h>
|
| 752 |
+
#include <ATen/ops/isclose_ops.h>
|
| 753 |
+
#include <ATen/ops/isfinite_ops.h>
|
| 754 |
+
#include <ATen/ops/isin_ops.h>
|
| 755 |
+
#include <ATen/ops/isinf_ops.h>
|
| 756 |
+
#include <ATen/ops/isnan_ops.h>
|
| 757 |
+
#include <ATen/ops/isneginf_ops.h>
|
| 758 |
+
#include <ATen/ops/isposinf_ops.h>
|
| 759 |
+
#include <ATen/ops/isreal_ops.h>
|
| 760 |
+
#include <ATen/ops/istft_ops.h>
|
| 761 |
+
#include <ATen/ops/item_ops.h>
|
| 762 |
+
#include <ATen/ops/kaiser_window_ops.h>
|
| 763 |
+
#include <ATen/ops/kl_div_ops.h>
|
| 764 |
+
#include <ATen/ops/kron_ops.h>
|
| 765 |
+
#include <ATen/ops/kthvalue_ops.h>
|
| 766 |
+
#include <ATen/ops/l1_loss_ops.h>
|
| 767 |
+
#include <ATen/ops/layer_norm_ops.h>
|
| 768 |
+
#include <ATen/ops/lcm_ops.h>
|
| 769 |
+
#include <ATen/ops/ldexp_ops.h>
|
| 770 |
+
#include <ATen/ops/le_ops.h>
|
| 771 |
+
#include <ATen/ops/leaky_relu_ops.h>
|
| 772 |
+
#include <ATen/ops/leaky_relu_backward_ops.h>
|
| 773 |
+
#include <ATen/ops/lerp_ops.h>
|
| 774 |
+
#include <ATen/ops/less_ops.h>
|
| 775 |
+
#include <ATen/ops/less_equal_ops.h>
|
| 776 |
+
#include <ATen/ops/lgamma_ops.h>
|
| 777 |
+
#include <ATen/ops/lift_ops.h>
|
| 778 |
+
#include <ATen/ops/lift_fresh_ops.h>
|
| 779 |
+
#include <ATen/ops/lift_fresh_copy_ops.h>
|
| 780 |
+
#include <ATen/ops/linalg_cholesky_ops.h>
|
| 781 |
+
#include <ATen/ops/linalg_cholesky_ex_ops.h>
|
| 782 |
+
#include <ATen/ops/linalg_cond_ops.h>
|
| 783 |
+
#include <ATen/ops/linalg_cross_ops.h>
|
| 784 |
+
#include <ATen/ops/linalg_det_ops.h>
|
| 785 |
+
#include <ATen/ops/linalg_diagonal_ops.h>
|
| 786 |
+
#include <ATen/ops/linalg_eig_ops.h>
|
| 787 |
+
#include <ATen/ops/linalg_eigh_ops.h>
|
| 788 |
+
#include <ATen/ops/linalg_eigvals_ops.h>
|
| 789 |
+
#include <ATen/ops/linalg_eigvalsh_ops.h>
|
| 790 |
+
#include <ATen/ops/linalg_householder_product_ops.h>
|
| 791 |
+
#include <ATen/ops/linalg_inv_ops.h>
|
| 792 |
+
#include <ATen/ops/linalg_inv_ex_ops.h>
|
| 793 |
+
#include <ATen/ops/linalg_ldl_factor_ops.h>
|
| 794 |
+
#include <ATen/ops/linalg_ldl_factor_ex_ops.h>
|
| 795 |
+
#include <ATen/ops/linalg_ldl_solve_ops.h>
|
| 796 |
+
#include <ATen/ops/linalg_lstsq_ops.h>
|
| 797 |
+
#include <ATen/ops/linalg_lu_ops.h>
|
| 798 |
+
#include <ATen/ops/linalg_lu_factor_ops.h>
|
| 799 |
+
#include <ATen/ops/linalg_lu_factor_ex_ops.h>
|
| 800 |
+
#include <ATen/ops/linalg_lu_solve_ops.h>
|
| 801 |
+
#include <ATen/ops/linalg_matmul_ops.h>
|
| 802 |
+
#include <ATen/ops/linalg_matrix_exp_ops.h>
|
| 803 |
+
#include <ATen/ops/linalg_matrix_norm_ops.h>
|
| 804 |
+
#include <ATen/ops/linalg_matrix_power_ops.h>
|
| 805 |
+
#include <ATen/ops/linalg_matrix_rank_ops.h>
|
| 806 |
+
#include <ATen/ops/linalg_multi_dot_ops.h>
|
| 807 |
+
#include <ATen/ops/linalg_norm_ops.h>
|
| 808 |
+
#include <ATen/ops/linalg_pinv_ops.h>
|
| 809 |
+
#include <ATen/ops/linalg_qr_ops.h>
|
| 810 |
+
#include <ATen/ops/linalg_slogdet_ops.h>
|
| 811 |
+
#include <ATen/ops/linalg_solve_ops.h>
|
| 812 |
+
#include <ATen/ops/linalg_solve_ex_ops.h>
|
| 813 |
+
#include <ATen/ops/linalg_solve_triangular_ops.h>
|
| 814 |
+
#include <ATen/ops/linalg_svd_ops.h>
|
| 815 |
+
#include <ATen/ops/linalg_svdvals_ops.h>
|
| 816 |
+
#include <ATen/ops/linalg_tensorinv_ops.h>
|
| 817 |
+
#include <ATen/ops/linalg_tensorsolve_ops.h>
|
| 818 |
+
#include <ATen/ops/linalg_vander_ops.h>
|
| 819 |
+
#include <ATen/ops/linalg_vecdot_ops.h>
|
| 820 |
+
#include <ATen/ops/linalg_vector_norm_ops.h>
|
| 821 |
+
#include <ATen/ops/linear_ops.h>
|
| 822 |
+
#include <ATen/ops/linear_backward_ops.h>
|
| 823 |
+
#include <ATen/ops/linspace_ops.h>
|
| 824 |
+
#include <ATen/ops/log_ops.h>
|
| 825 |
+
#include <ATen/ops/log10_ops.h>
|
| 826 |
+
#include <ATen/ops/log1p_ops.h>
|
| 827 |
+
#include <ATen/ops/log2_ops.h>
|
| 828 |
+
#include <ATen/ops/log_normal_ops.h>
|
| 829 |
+
#include <ATen/ops/log_sigmoid_ops.h>
|
| 830 |
+
#include <ATen/ops/log_sigmoid_backward_ops.h>
|
| 831 |
+
#include <ATen/ops/log_sigmoid_forward_ops.h>
|
| 832 |
+
#include <ATen/ops/log_softmax_ops.h>
|
| 833 |
+
#include <ATen/ops/logaddexp_ops.h>
|
| 834 |
+
#include <ATen/ops/logaddexp2_ops.h>
|
| 835 |
+
#include <ATen/ops/logcumsumexp_ops.h>
|
| 836 |
+
#include <ATen/ops/logdet_ops.h>
|
| 837 |
+
#include <ATen/ops/logical_and_ops.h>
|
| 838 |
+
#include <ATen/ops/logical_not_ops.h>
|
| 839 |
+
#include <ATen/ops/logical_or_ops.h>
|
| 840 |
+
#include <ATen/ops/logical_xor_ops.h>
|
| 841 |
+
#include <ATen/ops/logit_ops.h>
|
| 842 |
+
#include <ATen/ops/logit_backward_ops.h>
|
| 843 |
+
#include <ATen/ops/logspace_ops.h>
|
| 844 |
+
#include <ATen/ops/logsumexp_ops.h>
|
| 845 |
+
#include <ATen/ops/lshift_ops.h>
|
| 846 |
+
#include <ATen/ops/lstm_ops.h>
|
| 847 |
+
#include <ATen/ops/lstm_cell_ops.h>
|
| 848 |
+
#include <ATen/ops/lstm_mps_backward_ops.h>
|
| 849 |
+
#include <ATen/ops/lt_ops.h>
|
| 850 |
+
#include <ATen/ops/lu_solve_ops.h>
|
| 851 |
+
#include <ATen/ops/lu_unpack_ops.h>
|
| 852 |
+
#include <ATen/ops/mH_ops.h>
|
| 853 |
+
#include <ATen/ops/mT_ops.h>
|
| 854 |
+
#include <ATen/ops/margin_ranking_loss_ops.h>
|
| 855 |
+
#include <ATen/ops/masked_fill_ops.h>
|
| 856 |
+
#include <ATen/ops/masked_scatter_ops.h>
|
| 857 |
+
#include <ATen/ops/masked_scatter_backward_ops.h>
|
| 858 |
+
#include <ATen/ops/masked_select_ops.h>
|
| 859 |
+
#include <ATen/ops/masked_select_backward_ops.h>
|
| 860 |
+
#include <ATen/ops/matmul_ops.h>
|
| 861 |
+
#include <ATen/ops/matmul_backward_ops.h>
|
| 862 |
+
#include <ATen/ops/matrix_H_ops.h>
|
| 863 |
+
#include <ATen/ops/matrix_exp_ops.h>
|
| 864 |
+
#include <ATen/ops/matrix_exp_backward_ops.h>
|
| 865 |
+
#include <ATen/ops/matrix_power_ops.h>
|
| 866 |
+
#include <ATen/ops/max_ops.h>
|
| 867 |
+
#include <ATen/ops/max_pool1d_ops.h>
|
| 868 |
+
#include <ATen/ops/max_pool1d_with_indices_ops.h>
|
| 869 |
+
#include <ATen/ops/max_pool2d_ops.h>
|
| 870 |
+
#include <ATen/ops/max_pool2d_backward_ops.h>
|
| 871 |
+
#include <ATen/ops/max_pool2d_with_indices_ops.h>
|
| 872 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_ops.h>
|
| 873 |
+
#include <ATen/ops/max_pool3d_ops.h>
|
| 874 |
+
#include <ATen/ops/max_pool3d_with_indices_ops.h>
|
| 875 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_ops.h>
|
| 876 |
+
#include <ATen/ops/max_unpool2d_ops.h>
|
| 877 |
+
#include <ATen/ops/max_unpool3d_ops.h>
|
| 878 |
+
#include <ATen/ops/maximum_ops.h>
|
| 879 |
+
#include <ATen/ops/mean_ops.h>
|
| 880 |
+
#include <ATen/ops/median_ops.h>
|
| 881 |
+
#include <ATen/ops/meshgrid_ops.h>
|
| 882 |
+
#include <ATen/ops/min_ops.h>
|
| 883 |
+
#include <ATen/ops/minimum_ops.h>
|
| 884 |
+
#include <ATen/ops/miopen_batch_norm_ops.h>
|
| 885 |
+
#include <ATen/ops/miopen_batch_norm_backward_ops.h>
|
| 886 |
+
#include <ATen/ops/miopen_convolution_ops.h>
|
| 887 |
+
#include <ATen/ops/miopen_convolution_add_relu_ops.h>
|
| 888 |
+
#include <ATen/ops/miopen_convolution_relu_ops.h>
|
| 889 |
+
#include <ATen/ops/miopen_convolution_transpose_ops.h>
|
| 890 |
+
#include <ATen/ops/miopen_depthwise_convolution_ops.h>
|
| 891 |
+
#include <ATen/ops/miopen_rnn_ops.h>
|
| 892 |
+
#include <ATen/ops/miopen_rnn_backward_ops.h>
|
| 893 |
+
#include <ATen/ops/mish_ops.h>
|
| 894 |
+
#include <ATen/ops/mish_backward_ops.h>
|
| 895 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_ops.h>
|
| 896 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_ops.h>
|
| 897 |
+
#include <ATen/ops/mkldnn_convolution_ops.h>
|
| 898 |
+
#include <ATen/ops/mkldnn_linear_ops.h>
|
| 899 |
+
#include <ATen/ops/mkldnn_linear_backward_ops.h>
|
| 900 |
+
#include <ATen/ops/mkldnn_linear_backward_input_ops.h>
|
| 901 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_ops.h>
|
| 902 |
+
#include <ATen/ops/mkldnn_max_pool2d_ops.h>
|
| 903 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_ops.h>
|
| 904 |
+
#include <ATen/ops/mkldnn_max_pool3d_ops.h>
|
| 905 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_ops.h>
|
| 906 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_ops.h>
|
| 907 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_ops.h>
|
| 908 |
+
#include <ATen/ops/mkldnn_rnn_layer_ops.h>
|
| 909 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_ops.h>
|
| 910 |
+
#include <ATen/ops/mm_ops.h>
|
| 911 |
+
#include <ATen/ops/mode_ops.h>
|
| 912 |
+
#include <ATen/ops/moveaxis_ops.h>
|
| 913 |
+
#include <ATen/ops/movedim_ops.h>
|
| 914 |
+
#include <ATen/ops/mps_convolution_backward_ops.h>
|
| 915 |
+
#include <ATen/ops/mps_convolution_transpose_backward_ops.h>
|
| 916 |
+
#include <ATen/ops/mse_loss_ops.h>
|
| 917 |
+
#include <ATen/ops/mse_loss_backward_ops.h>
|
| 918 |
+
#include <ATen/ops/msort_ops.h>
|
| 919 |
+
#include <ATen/ops/mul_ops.h>
|
| 920 |
+
#include <ATen/ops/multi_margin_loss_ops.h>
|
| 921 |
+
#include <ATen/ops/multi_margin_loss_backward_ops.h>
|
| 922 |
+
#include <ATen/ops/multilabel_margin_loss_ops.h>
|
| 923 |
+
#include <ATen/ops/multilabel_margin_loss_backward_ops.h>
|
| 924 |
+
#include <ATen/ops/multilabel_margin_loss_forward_ops.h>
|
| 925 |
+
#include <ATen/ops/multinomial_ops.h>
|
| 926 |
+
#include <ATen/ops/multiply_ops.h>
|
| 927 |
+
#include <ATen/ops/mv_ops.h>
|
| 928 |
+
#include <ATen/ops/mvlgamma_ops.h>
|
| 929 |
+
#include <ATen/ops/nan_to_num_ops.h>
|
| 930 |
+
#include <ATen/ops/nanmean_ops.h>
|
| 931 |
+
#include <ATen/ops/nanmedian_ops.h>
|
| 932 |
+
#include <ATen/ops/nanquantile_ops.h>
|
| 933 |
+
#include <ATen/ops/nansum_ops.h>
|
| 934 |
+
#include <ATen/ops/narrow_ops.h>
|
| 935 |
+
#include <ATen/ops/narrow_copy_ops.h>
|
| 936 |
+
#include <ATen/ops/native_batch_norm_ops.h>
|
| 937 |
+
#include <ATen/ops/native_batch_norm_backward_ops.h>
|
| 938 |
+
#include <ATen/ops/native_channel_shuffle_ops.h>
|
| 939 |
+
#include <ATen/ops/native_dropout_ops.h>
|
| 940 |
+
#include <ATen/ops/native_dropout_backward_ops.h>
|
| 941 |
+
#include <ATen/ops/native_group_norm_ops.h>
|
| 942 |
+
#include <ATen/ops/native_group_norm_backward_ops.h>
|
| 943 |
+
#include <ATen/ops/native_layer_norm_ops.h>
|
| 944 |
+
#include <ATen/ops/native_layer_norm_backward_ops.h>
|
| 945 |
+
#include <ATen/ops/native_norm_ops.h>
|
| 946 |
+
#include <ATen/ops/ne_ops.h>
|
| 947 |
+
#include <ATen/ops/neg_ops.h>
|
| 948 |
+
#include <ATen/ops/negative_ops.h>
|
| 949 |
+
#include <ATen/ops/nested_to_padded_tensor_ops.h>
|
| 950 |
+
#include <ATen/ops/new_empty_ops.h>
|
| 951 |
+
#include <ATen/ops/new_empty_strided_ops.h>
|
| 952 |
+
#include <ATen/ops/new_full_ops.h>
|
| 953 |
+
#include <ATen/ops/new_ones_ops.h>
|
| 954 |
+
#include <ATen/ops/new_zeros_ops.h>
|
| 955 |
+
#include <ATen/ops/nextafter_ops.h>
|
| 956 |
+
#include <ATen/ops/nll_loss_ops.h>
|
| 957 |
+
#include <ATen/ops/nll_loss2d_ops.h>
|
| 958 |
+
#include <ATen/ops/nll_loss2d_backward_ops.h>
|
| 959 |
+
#include <ATen/ops/nll_loss2d_forward_ops.h>
|
| 960 |
+
#include <ATen/ops/nll_loss_backward_ops.h>
|
| 961 |
+
#include <ATen/ops/nll_loss_forward_ops.h>
|
| 962 |
+
#include <ATen/ops/nll_loss_nd_ops.h>
|
| 963 |
+
#include <ATen/ops/nonzero_ops.h>
|
| 964 |
+
#include <ATen/ops/nonzero_numpy_ops.h>
|
| 965 |
+
#include <ATen/ops/nonzero_static_ops.h>
|
| 966 |
+
#include <ATen/ops/norm_ops.h>
|
| 967 |
+
#include <ATen/ops/norm_except_dim_ops.h>
|
| 968 |
+
#include <ATen/ops/normal_ops.h>
|
| 969 |
+
#include <ATen/ops/not_equal_ops.h>
|
| 970 |
+
#include <ATen/ops/nuclear_norm_ops.h>
|
| 971 |
+
#include <ATen/ops/numpy_T_ops.h>
|
| 972 |
+
#include <ATen/ops/one_hot_ops.h>
|
| 973 |
+
#include <ATen/ops/ones_ops.h>
|
| 974 |
+
#include <ATen/ops/ones_like_ops.h>
|
| 975 |
+
#include <ATen/ops/or_ops.h>
|
| 976 |
+
#include <ATen/ops/orgqr_ops.h>
|
| 977 |
+
#include <ATen/ops/ormqr_ops.h>
|
| 978 |
+
#include <ATen/ops/outer_ops.h>
|
| 979 |
+
#include <ATen/ops/output_nr_ops.h>
|
| 980 |
+
#include <ATen/ops/pad_ops.h>
|
| 981 |
+
#include <ATen/ops/pad_sequence_ops.h>
|
| 982 |
+
#include <ATen/ops/pairwise_distance_ops.h>
|
| 983 |
+
#include <ATen/ops/pdist_ops.h>
|
| 984 |
+
#include <ATen/ops/permute_ops.h>
|
| 985 |
+
#include <ATen/ops/permute_copy_ops.h>
|
| 986 |
+
#include <ATen/ops/pin_memory_ops.h>
|
| 987 |
+
#include <ATen/ops/pinverse_ops.h>
|
| 988 |
+
#include <ATen/ops/pixel_shuffle_ops.h>
|
| 989 |
+
#include <ATen/ops/pixel_unshuffle_ops.h>
|
| 990 |
+
#include <ATen/ops/poisson_ops.h>
|
| 991 |
+
#include <ATen/ops/poisson_nll_loss_ops.h>
|
| 992 |
+
#include <ATen/ops/polar_ops.h>
|
| 993 |
+
#include <ATen/ops/polygamma_ops.h>
|
| 994 |
+
#include <ATen/ops/positive_ops.h>
|
| 995 |
+
#include <ATen/ops/pow_ops.h>
|
| 996 |
+
#include <ATen/ops/prelu_ops.h>
|
| 997 |
+
#include <ATen/ops/prod_ops.h>
|
| 998 |
+
#include <ATen/ops/promote_types_ops.h>
|
| 999 |
+
#include <ATen/ops/put_ops.h>
|
| 1000 |
+
#include <ATen/ops/q_per_channel_axis_ops.h>
|
| 1001 |
+
#include <ATen/ops/q_per_channel_scales_ops.h>
|
| 1002 |
+
#include <ATen/ops/q_per_channel_zero_points_ops.h>
|
| 1003 |
+
#include <ATen/ops/q_scale_ops.h>
|
| 1004 |
+
#include <ATen/ops/q_zero_point_ops.h>
|
| 1005 |
+
#include <ATen/ops/qr_ops.h>
|
| 1006 |
+
#include <ATen/ops/qscheme_ops.h>
|
| 1007 |
+
#include <ATen/ops/quantile_ops.h>
|
| 1008 |
+
#include <ATen/ops/quantize_per_channel_ops.h>
|
| 1009 |
+
#include <ATen/ops/quantize_per_tensor_ops.h>
|
| 1010 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_ops.h>
|
| 1011 |
+
#include <ATen/ops/quantized_batch_norm_ops.h>
|
| 1012 |
+
#include <ATen/ops/quantized_gru_cell_ops.h>
|
| 1013 |
+
#include <ATen/ops/quantized_lstm_cell_ops.h>
|
| 1014 |
+
#include <ATen/ops/quantized_max_pool1d_ops.h>
|
| 1015 |
+
#include <ATen/ops/quantized_max_pool2d_ops.h>
|
| 1016 |
+
#include <ATen/ops/quantized_max_pool3d_ops.h>
|
| 1017 |
+
#include <ATen/ops/quantized_rnn_relu_cell_ops.h>
|
| 1018 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_ops.h>
|
| 1019 |
+
#include <ATen/ops/rad2deg_ops.h>
|
| 1020 |
+
#include <ATen/ops/rand_ops.h>
|
| 1021 |
+
#include <ATen/ops/rand_like_ops.h>
|
| 1022 |
+
#include <ATen/ops/randint_ops.h>
|
| 1023 |
+
#include <ATen/ops/randint_like_ops.h>
|
| 1024 |
+
#include <ATen/ops/randn_ops.h>
|
| 1025 |
+
#include <ATen/ops/randn_like_ops.h>
|
| 1026 |
+
#include <ATen/ops/random_ops.h>
|
| 1027 |
+
#include <ATen/ops/randperm_ops.h>
|
| 1028 |
+
#include <ATen/ops/range_ops.h>
|
| 1029 |
+
#include <ATen/ops/ravel_ops.h>
|
| 1030 |
+
#include <ATen/ops/real_ops.h>
|
| 1031 |
+
#include <ATen/ops/reciprocal_ops.h>
|
| 1032 |
+
#include <ATen/ops/record_stream_ops.h>
|
| 1033 |
+
#include <ATen/ops/refine_names_ops.h>
|
| 1034 |
+
#include <ATen/ops/reflection_pad1d_ops.h>
|
| 1035 |
+
#include <ATen/ops/reflection_pad1d_backward_ops.h>
|
| 1036 |
+
#include <ATen/ops/reflection_pad2d_ops.h>
|
| 1037 |
+
#include <ATen/ops/reflection_pad2d_backward_ops.h>
|
| 1038 |
+
#include <ATen/ops/reflection_pad3d_ops.h>
|
| 1039 |
+
#include <ATen/ops/reflection_pad3d_backward_ops.h>
|
| 1040 |
+
#include <ATen/ops/relu_ops.h>
|
| 1041 |
+
#include <ATen/ops/relu6_ops.h>
|
| 1042 |
+
#include <ATen/ops/remainder_ops.h>
|
| 1043 |
+
#include <ATen/ops/rename_ops.h>
|
| 1044 |
+
#include <ATen/ops/renorm_ops.h>
|
| 1045 |
+
#include <ATen/ops/repeat_ops.h>
|
| 1046 |
+
#include <ATen/ops/repeat_interleave_ops.h>
|
| 1047 |
+
#include <ATen/ops/replication_pad1d_ops.h>
|
| 1048 |
+
#include <ATen/ops/replication_pad1d_backward_ops.h>
|
| 1049 |
+
#include <ATen/ops/replication_pad2d_ops.h>
|
| 1050 |
+
#include <ATen/ops/replication_pad2d_backward_ops.h>
|
| 1051 |
+
#include <ATen/ops/replication_pad3d_ops.h>
|
| 1052 |
+
#include <ATen/ops/replication_pad3d_backward_ops.h>
|
| 1053 |
+
#include <ATen/ops/requires_grad_ops.h>
|
| 1054 |
+
#include <ATen/ops/reshape_ops.h>
|
| 1055 |
+
#include <ATen/ops/reshape_as_ops.h>
|
| 1056 |
+
#include <ATen/ops/resize_ops.h>
|
| 1057 |
+
#include <ATen/ops/resize_as_ops.h>
|
| 1058 |
+
#include <ATen/ops/resize_as_sparse_ops.h>
|
| 1059 |
+
#include <ATen/ops/resolve_conj_ops.h>
|
| 1060 |
+
#include <ATen/ops/resolve_neg_ops.h>
|
| 1061 |
+
#include <ATen/ops/result_type_ops.h>
|
| 1062 |
+
#include <ATen/ops/retain_grad_ops.h>
|
| 1063 |
+
#include <ATen/ops/retains_grad_ops.h>
|
| 1064 |
+
#include <ATen/ops/rnn_relu_ops.h>
|
| 1065 |
+
#include <ATen/ops/rnn_relu_cell_ops.h>
|
| 1066 |
+
#include <ATen/ops/rnn_tanh_ops.h>
|
| 1067 |
+
#include <ATen/ops/rnn_tanh_cell_ops.h>
|
| 1068 |
+
#include <ATen/ops/roll_ops.h>
|
| 1069 |
+
#include <ATen/ops/rot90_ops.h>
|
| 1070 |
+
#include <ATen/ops/round_ops.h>
|
| 1071 |
+
#include <ATen/ops/row_indices_ops.h>
|
| 1072 |
+
#include <ATen/ops/row_indices_copy_ops.h>
|
| 1073 |
+
#include <ATen/ops/row_stack_ops.h>
|
| 1074 |
+
#include <ATen/ops/rrelu_ops.h>
|
| 1075 |
+
#include <ATen/ops/rrelu_with_noise_ops.h>
|
| 1076 |
+
#include <ATen/ops/rrelu_with_noise_backward_ops.h>
|
| 1077 |
+
#include <ATen/ops/rshift_ops.h>
|
| 1078 |
+
#include <ATen/ops/rsqrt_ops.h>
|
| 1079 |
+
#include <ATen/ops/rsub_ops.h>
|
| 1080 |
+
#include <ATen/ops/scalar_tensor_ops.h>
|
| 1081 |
+
#include <ATen/ops/scaled_dot_product_attention_ops.h>
|
| 1082 |
+
#include <ATen/ops/scatter_ops.h>
|
| 1083 |
+
#include <ATen/ops/scatter_add_ops.h>
|
| 1084 |
+
#include <ATen/ops/scatter_reduce_ops.h>
|
| 1085 |
+
#include <ATen/ops/searchsorted_ops.h>
|
| 1086 |
+
#include <ATen/ops/segment_reduce_ops.h>
|
| 1087 |
+
#include <ATen/ops/select_ops.h>
|
| 1088 |
+
#include <ATen/ops/select_backward_ops.h>
|
| 1089 |
+
#include <ATen/ops/select_copy_ops.h>
|
| 1090 |
+
#include <ATen/ops/select_scatter_ops.h>
|
| 1091 |
+
#include <ATen/ops/selu_ops.h>
|
| 1092 |
+
#include <ATen/ops/set_ops.h>
|
| 1093 |
+
#include <ATen/ops/set_data_ops.h>
|
| 1094 |
+
#include <ATen/ops/sgn_ops.h>
|
| 1095 |
+
#include <ATen/ops/sigmoid_ops.h>
|
| 1096 |
+
#include <ATen/ops/sigmoid_backward_ops.h>
|
| 1097 |
+
#include <ATen/ops/sign_ops.h>
|
| 1098 |
+
#include <ATen/ops/signbit_ops.h>
|
| 1099 |
+
#include <ATen/ops/silu_ops.h>
|
| 1100 |
+
#include <ATen/ops/silu_backward_ops.h>
|
| 1101 |
+
#include <ATen/ops/sin_ops.h>
|
| 1102 |
+
#include <ATen/ops/sinc_ops.h>
|
| 1103 |
+
#include <ATen/ops/sinh_ops.h>
|
| 1104 |
+
#include <ATen/ops/size_ops.h>
|
| 1105 |
+
#include <ATen/ops/slice_ops.h>
|
| 1106 |
+
#include <ATen/ops/slice_backward_ops.h>
|
| 1107 |
+
#include <ATen/ops/slice_copy_ops.h>
|
| 1108 |
+
#include <ATen/ops/slice_inverse_ops.h>
|
| 1109 |
+
#include <ATen/ops/slice_scatter_ops.h>
|
| 1110 |
+
#include <ATen/ops/slogdet_ops.h>
|
| 1111 |
+
#include <ATen/ops/slow_conv3d_ops.h>
|
| 1112 |
+
#include <ATen/ops/slow_conv3d_forward_ops.h>
|
| 1113 |
+
#include <ATen/ops/slow_conv_dilated2d_ops.h>
|
| 1114 |
+
#include <ATen/ops/slow_conv_dilated3d_ops.h>
|
| 1115 |
+
#include <ATen/ops/slow_conv_transpose2d_ops.h>
|
| 1116 |
+
#include <ATen/ops/slow_conv_transpose3d_ops.h>
|
| 1117 |
+
#include <ATen/ops/smm_ops.h>
|
| 1118 |
+
#include <ATen/ops/smooth_l1_loss_ops.h>
|
| 1119 |
+
#include <ATen/ops/smooth_l1_loss_backward_ops.h>
|
| 1120 |
+
#include <ATen/ops/soft_margin_loss_ops.h>
|
| 1121 |
+
#include <ATen/ops/soft_margin_loss_backward_ops.h>
|
| 1122 |
+
#include <ATen/ops/softmax_ops.h>
|
| 1123 |
+
#include <ATen/ops/softplus_ops.h>
|
| 1124 |
+
#include <ATen/ops/softplus_backward_ops.h>
|
| 1125 |
+
#include <ATen/ops/softshrink_ops.h>
|
| 1126 |
+
#include <ATen/ops/softshrink_backward_ops.h>
|
| 1127 |
+
#include <ATen/ops/sort_ops.h>
|
| 1128 |
+
#include <ATen/ops/sparse_bsc_tensor_ops.h>
|
| 1129 |
+
#include <ATen/ops/sparse_bsr_tensor_ops.h>
|
| 1130 |
+
#include <ATen/ops/sparse_compressed_tensor_ops.h>
|
| 1131 |
+
#include <ATen/ops/sparse_coo_tensor_ops.h>
|
| 1132 |
+
#include <ATen/ops/sparse_csc_tensor_ops.h>
|
| 1133 |
+
#include <ATen/ops/sparse_csr_tensor_ops.h>
|
| 1134 |
+
#include <ATen/ops/sparse_dim_ops.h>
|
| 1135 |
+
#include <ATen/ops/sparse_mask_ops.h>
|
| 1136 |
+
#include <ATen/ops/sparse_resize_ops.h>
|
| 1137 |
+
#include <ATen/ops/sparse_resize_and_clear_ops.h>
|
| 1138 |
+
#include <ATen/ops/sparse_sampled_addmm_ops.h>
|
| 1139 |
+
#include <ATen/ops/special_airy_ai_ops.h>
|
| 1140 |
+
#include <ATen/ops/special_bessel_j0_ops.h>
|
| 1141 |
+
#include <ATen/ops/special_bessel_j1_ops.h>
|
| 1142 |
+
#include <ATen/ops/special_bessel_y0_ops.h>
|
| 1143 |
+
#include <ATen/ops/special_bessel_y1_ops.h>
|
| 1144 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_ops.h>
|
| 1145 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_ops.h>
|
| 1146 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_ops.h>
|
| 1147 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_ops.h>
|
| 1148 |
+
#include <ATen/ops/special_digamma_ops.h>
|
| 1149 |
+
#include <ATen/ops/special_entr_ops.h>
|
| 1150 |
+
#include <ATen/ops/special_erf_ops.h>
|
| 1151 |
+
#include <ATen/ops/special_erfc_ops.h>
|
| 1152 |
+
#include <ATen/ops/special_erfcx_ops.h>
|
| 1153 |
+
#include <ATen/ops/special_erfinv_ops.h>
|
| 1154 |
+
#include <ATen/ops/special_exp2_ops.h>
|
| 1155 |
+
#include <ATen/ops/special_expit_ops.h>
|
| 1156 |
+
#include <ATen/ops/special_expm1_ops.h>
|
| 1157 |
+
#include <ATen/ops/special_gammainc_ops.h>
|
| 1158 |
+
#include <ATen/ops/special_gammaincc_ops.h>
|
| 1159 |
+
#include <ATen/ops/special_gammaln_ops.h>
|
| 1160 |
+
#include <ATen/ops/special_hermite_polynomial_h_ops.h>
|
| 1161 |
+
#include <ATen/ops/special_hermite_polynomial_he_ops.h>
|
| 1162 |
+
#include <ATen/ops/special_i0_ops.h>
|
| 1163 |
+
#include <ATen/ops/special_i0e_ops.h>
|
| 1164 |
+
#include <ATen/ops/special_i1_ops.h>
|
| 1165 |
+
#include <ATen/ops/special_i1e_ops.h>
|
| 1166 |
+
#include <ATen/ops/special_laguerre_polynomial_l_ops.h>
|
| 1167 |
+
#include <ATen/ops/special_legendre_polynomial_p_ops.h>
|
| 1168 |
+
#include <ATen/ops/special_log1p_ops.h>
|
| 1169 |
+
#include <ATen/ops/special_log_ndtr_ops.h>
|
| 1170 |
+
#include <ATen/ops/special_log_softmax_ops.h>
|
| 1171 |
+
#include <ATen/ops/special_logit_ops.h>
|
| 1172 |
+
#include <ATen/ops/special_logsumexp_ops.h>
|
| 1173 |
+
#include <ATen/ops/special_modified_bessel_i0_ops.h>
|
| 1174 |
+
#include <ATen/ops/special_modified_bessel_i1_ops.h>
|
| 1175 |
+
#include <ATen/ops/special_modified_bessel_k0_ops.h>
|
| 1176 |
+
#include <ATen/ops/special_modified_bessel_k1_ops.h>
|
| 1177 |
+
#include <ATen/ops/special_multigammaln_ops.h>
|
| 1178 |
+
#include <ATen/ops/special_ndtr_ops.h>
|
| 1179 |
+
#include <ATen/ops/special_ndtri_ops.h>
|
| 1180 |
+
#include <ATen/ops/special_polygamma_ops.h>
|
| 1181 |
+
#include <ATen/ops/special_psi_ops.h>
|
| 1182 |
+
#include <ATen/ops/special_round_ops.h>
|
| 1183 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_ops.h>
|
| 1184 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_ops.h>
|
| 1185 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_ops.h>
|
| 1186 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_ops.h>
|
| 1187 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_ops.h>
|
| 1188 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_ops.h>
|
| 1189 |
+
#include <ATen/ops/special_sinc_ops.h>
|
| 1190 |
+
#include <ATen/ops/special_softmax_ops.h>
|
| 1191 |
+
#include <ATen/ops/special_spherical_bessel_j0_ops.h>
|
| 1192 |
+
#include <ATen/ops/special_xlog1py_ops.h>
|
| 1193 |
+
#include <ATen/ops/special_xlogy_ops.h>
|
| 1194 |
+
#include <ATen/ops/special_zeta_ops.h>
|
| 1195 |
+
#include <ATen/ops/split_ops.h>
|
| 1196 |
+
#include <ATen/ops/split_copy_ops.h>
|
| 1197 |
+
#include <ATen/ops/split_with_sizes_ops.h>
|
| 1198 |
+
#include <ATen/ops/split_with_sizes_copy_ops.h>
|
| 1199 |
+
#include <ATen/ops/sqrt_ops.h>
|
| 1200 |
+
#include <ATen/ops/square_ops.h>
|
| 1201 |
+
#include <ATen/ops/squeeze_ops.h>
|
| 1202 |
+
#include <ATen/ops/squeeze_copy_ops.h>
|
| 1203 |
+
#include <ATen/ops/sspaddmm_ops.h>
|
| 1204 |
+
#include <ATen/ops/stack_ops.h>
|
| 1205 |
+
#include <ATen/ops/std_ops.h>
|
| 1206 |
+
#include <ATen/ops/std_mean_ops.h>
|
| 1207 |
+
#include <ATen/ops/stft_ops.h>
|
| 1208 |
+
#include <ATen/ops/stride_ops.h>
|
| 1209 |
+
#include <ATen/ops/sub_ops.h>
|
| 1210 |
+
#include <ATen/ops/subtract_ops.h>
|
| 1211 |
+
#include <ATen/ops/sum_ops.h>
|
| 1212 |
+
#include <ATen/ops/sum_to_size_ops.h>
|
| 1213 |
+
#include <ATen/ops/svd_ops.h>
|
| 1214 |
+
#include <ATen/ops/swapaxes_ops.h>
|
| 1215 |
+
#include <ATen/ops/swapdims_ops.h>
|
| 1216 |
+
#include <ATen/ops/sym_constrain_range_ops.h>
|
| 1217 |
+
#include <ATen/ops/sym_constrain_range_for_size_ops.h>
|
| 1218 |
+
#include <ATen/ops/sym_numel_ops.h>
|
| 1219 |
+
#include <ATen/ops/sym_size_ops.h>
|
| 1220 |
+
#include <ATen/ops/sym_storage_offset_ops.h>
|
| 1221 |
+
#include <ATen/ops/sym_stride_ops.h>
|
| 1222 |
+
#include <ATen/ops/t_ops.h>
|
| 1223 |
+
#include <ATen/ops/t_copy_ops.h>
|
| 1224 |
+
#include <ATen/ops/take_ops.h>
|
| 1225 |
+
#include <ATen/ops/take_along_dim_ops.h>
|
| 1226 |
+
#include <ATen/ops/tan_ops.h>
|
| 1227 |
+
#include <ATen/ops/tanh_ops.h>
|
| 1228 |
+
#include <ATen/ops/tanh_backward_ops.h>
|
| 1229 |
+
#include <ATen/ops/tensor_split_ops.h>
|
| 1230 |
+
#include <ATen/ops/tensordot_ops.h>
|
| 1231 |
+
#include <ATen/ops/thnn_conv2d_ops.h>
|
| 1232 |
+
#include <ATen/ops/threshold_ops.h>
|
| 1233 |
+
#include <ATen/ops/threshold_backward_ops.h>
|
| 1234 |
+
#include <ATen/ops/tile_ops.h>
|
| 1235 |
+
#include <ATen/ops/to_ops.h>
|
| 1236 |
+
#include <ATen/ops/to_dense_ops.h>
|
| 1237 |
+
#include <ATen/ops/to_dense_backward_ops.h>
|
| 1238 |
+
#include <ATen/ops/to_mkldnn_ops.h>
|
| 1239 |
+
#include <ATen/ops/to_mkldnn_backward_ops.h>
|
| 1240 |
+
#include <ATen/ops/to_padded_tensor_ops.h>
|
| 1241 |
+
#include <ATen/ops/to_sparse_ops.h>
|
| 1242 |
+
#include <ATen/ops/to_sparse_bsc_ops.h>
|
| 1243 |
+
#include <ATen/ops/to_sparse_bsr_ops.h>
|
| 1244 |
+
#include <ATen/ops/to_sparse_csc_ops.h>
|
| 1245 |
+
#include <ATen/ops/to_sparse_csr_ops.h>
|
| 1246 |
+
#include <ATen/ops/topk_ops.h>
|
| 1247 |
+
#include <ATen/ops/trace_ops.h>
|
| 1248 |
+
#include <ATen/ops/trace_backward_ops.h>
|
| 1249 |
+
#include <ATen/ops/transpose_ops.h>
|
| 1250 |
+
#include <ATen/ops/transpose_copy_ops.h>
|
| 1251 |
+
#include <ATen/ops/trapezoid_ops.h>
|
| 1252 |
+
#include <ATen/ops/trapz_ops.h>
|
| 1253 |
+
#include <ATen/ops/triangular_solve_ops.h>
|
| 1254 |
+
#include <ATen/ops/tril_ops.h>
|
| 1255 |
+
#include <ATen/ops/tril_indices_ops.h>
|
| 1256 |
+
#include <ATen/ops/triplet_margin_loss_ops.h>
|
| 1257 |
+
#include <ATen/ops/triu_ops.h>
|
| 1258 |
+
#include <ATen/ops/triu_indices_ops.h>
|
| 1259 |
+
#include <ATen/ops/true_divide_ops.h>
|
| 1260 |
+
#include <ATen/ops/trunc_ops.h>
|
| 1261 |
+
#include <ATen/ops/type_as_ops.h>
|
| 1262 |
+
#include <ATen/ops/unbind_ops.h>
|
| 1263 |
+
#include <ATen/ops/unbind_copy_ops.h>
|
| 1264 |
+
#include <ATen/ops/unflatten_ops.h>
|
| 1265 |
+
#include <ATen/ops/unflatten_dense_tensors_ops.h>
|
| 1266 |
+
#include <ATen/ops/unfold_ops.h>
|
| 1267 |
+
#include <ATen/ops/unfold_backward_ops.h>
|
| 1268 |
+
#include <ATen/ops/unfold_copy_ops.h>
|
| 1269 |
+
#include <ATen/ops/uniform_ops.h>
|
| 1270 |
+
#include <ATen/ops/unique_consecutive_ops.h>
|
| 1271 |
+
#include <ATen/ops/unique_dim_ops.h>
|
| 1272 |
+
#include <ATen/ops/unique_dim_consecutive_ops.h>
|
| 1273 |
+
#include <ATen/ops/unsafe_chunk_ops.h>
|
| 1274 |
+
#include <ATen/ops/unsafe_split_ops.h>
|
| 1275 |
+
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
|
| 1276 |
+
#include <ATen/ops/unsqueeze_ops.h>
|
| 1277 |
+
#include <ATen/ops/unsqueeze_copy_ops.h>
|
| 1278 |
+
#include <ATen/ops/upsample_bicubic2d_ops.h>
|
| 1279 |
+
#include <ATen/ops/upsample_bicubic2d_backward_ops.h>
|
| 1280 |
+
#include <ATen/ops/upsample_bilinear2d_ops.h>
|
| 1281 |
+
#include <ATen/ops/upsample_bilinear2d_backward_ops.h>
|
| 1282 |
+
#include <ATen/ops/upsample_linear1d_ops.h>
|
| 1283 |
+
#include <ATen/ops/upsample_linear1d_backward_ops.h>
|
| 1284 |
+
#include <ATen/ops/upsample_nearest1d_ops.h>
|
| 1285 |
+
#include <ATen/ops/upsample_nearest1d_backward_ops.h>
|
| 1286 |
+
#include <ATen/ops/upsample_nearest2d_ops.h>
|
| 1287 |
+
#include <ATen/ops/upsample_nearest2d_backward_ops.h>
|
| 1288 |
+
#include <ATen/ops/upsample_nearest3d_ops.h>
|
| 1289 |
+
#include <ATen/ops/upsample_nearest3d_backward_ops.h>
|
| 1290 |
+
#include <ATen/ops/upsample_trilinear3d_ops.h>
|
| 1291 |
+
#include <ATen/ops/upsample_trilinear3d_backward_ops.h>
|
| 1292 |
+
#include <ATen/ops/value_selecting_reduction_backward_ops.h>
|
| 1293 |
+
#include <ATen/ops/values_ops.h>
|
| 1294 |
+
#include <ATen/ops/values_copy_ops.h>
|
| 1295 |
+
#include <ATen/ops/vander_ops.h>
|
| 1296 |
+
#include <ATen/ops/var_ops.h>
|
| 1297 |
+
#include <ATen/ops/var_mean_ops.h>
|
| 1298 |
+
#include <ATen/ops/vdot_ops.h>
|
| 1299 |
+
#include <ATen/ops/view_ops.h>
|
| 1300 |
+
#include <ATen/ops/view_as_ops.h>
|
| 1301 |
+
#include <ATen/ops/view_as_complex_ops.h>
|
| 1302 |
+
#include <ATen/ops/view_as_complex_copy_ops.h>
|
| 1303 |
+
#include <ATen/ops/view_as_real_ops.h>
|
| 1304 |
+
#include <ATen/ops/view_as_real_copy_ops.h>
|
| 1305 |
+
#include <ATen/ops/view_copy_ops.h>
|
| 1306 |
+
#include <ATen/ops/vsplit_ops.h>
|
| 1307 |
+
#include <ATen/ops/vstack_ops.h>
|
| 1308 |
+
#include <ATen/ops/where_ops.h>
|
| 1309 |
+
#include <ATen/ops/xlogy_ops.h>
|
| 1310 |
+
#include <ATen/ops/xor_ops.h>
|
| 1311 |
+
#include <ATen/ops/zero_ops.h>
|
| 1312 |
+
#include <ATen/ops/zeros_ops.h>
|
| 1313 |
+
#include <ATen/ops/zeros_like_ops.h>
|
| 1314 |
+
|
| 1315 |
+
// Extension writers: do you write wrapper functions? Are you frustrated with
|
| 1316 |
+
// resolving overloads of operators? Are you frustrated with dealing with
|
| 1317 |
+
// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no
|
| 1318 |
+
// further, this is the utility for you.
|
| 1319 |
+
//
|
| 1320 |
+
// Given an operator schema: aten::op.overload(...
|
| 1321 |
+
//
|
| 1322 |
+
// Use ATEN_FN2(op, overload) to get a *function* version of the operator
|
| 1323 |
+
// that is guaranteed to not be overloaded. This means that you can safely
|
| 1324 |
+
// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args.
|
| 1325 |
+
//
|
| 1326 |
+
// Given an operator schema without an overload name: aten::op(...
|
| 1327 |
+
//
|
| 1328 |
+
// Use ATEN_FN(op) to get an unambiguous *function* version of the operator.
|
| 1329 |
+
//
|
| 1330 |
+
// There is some interesting behavior for out= operations.
|
| 1331 |
+
// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema;
|
| 1332 |
+
// that is, the order of arguments is exactly what it looks like in the schema.
|
| 1333 |
+
|
| 1334 |
+
#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call
|
| 1335 |
+
#define ATEN_FN(op_name) at::_ops::op_name::call
|
| 1336 |
+
|
| 1337 |
+
// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time
|
| 1338 |
+
// metadata about a given aten operator.
|
| 1339 |
+
// Notable data on the class includes:
|
| 1340 |
+
// - ATEN_OP2(add, Tensor)::name // returns the string name: "add"
|
| 1341 |
+
// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor"
|
| 1342 |
+
// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &)
|
| 1343 |
+
// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
|
| 1344 |
+
|
| 1345 |
+
#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload
|
| 1346 |
+
#define ATEN_OP(op_name) at::_ops::op_name
|
| 1347 |
+
|
| 1348 |
+
// WARNING: Please do not call any of the ops in the _ops namespace directly.
|
| 1349 |
+
// Use the ATEN_FN macros. We do not guarantee stability of the naming
|
| 1350 |
+
// scheme for the functions in at::_ops
|
| 1351 |
+
|
| 1352 |
+
// See Note [The ATen Operators API] for details of the at::_ops namespace
|
| 1353 |
+
|
| 1354 |
+
namespace at {
|
| 1355 |
+
namespace _ops {
|
| 1356 |
+
|
| 1357 |
+
} // namespace _ops
|
| 1358 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
#include <c10/core/thread_pool.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
class TORCH_API PTThreadPool : public c10::ThreadPool {
|
| 9 |
+
public:
|
| 10 |
+
explicit PTThreadPool(int pool_size, int numa_node_id = -1)
|
| 11 |
+
: c10::ThreadPool(pool_size, numa_node_id, []() {
|
| 12 |
+
c10::setThreadName("PTThreadPool");
|
| 13 |
+
at::init_num_threads();
|
| 14 |
+
}) {}
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNativeTBB.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <atomic>
|
| 4 |
+
#include <cstddef>
|
| 5 |
+
#include <exception>
|
| 6 |
+
|
| 7 |
+
#include <c10/util/Exception.h>
|
| 8 |
+
|
| 9 |
+
#ifdef _WIN32
|
| 10 |
+
#ifndef WIN32_LEAN_AND_MEAN
|
| 11 |
+
#define WIN32_LEAN_AND_MEAN
|
| 12 |
+
#endif
|
| 13 |
+
#endif
|
| 14 |
+
#include <tbb/tbb.h>
|
| 15 |
+
|
| 16 |
+
#define INTRA_OP_PARALLEL
|
| 17 |
+
|
| 18 |
+
namespace at::internal {
|
| 19 |
+
|
| 20 |
+
template <typename F>
|
| 21 |
+
inline void invoke_parallel(
|
| 22 |
+
const int64_t begin,
|
| 23 |
+
const int64_t end,
|
| 24 |
+
const int64_t grain_size,
|
| 25 |
+
const F& f) {
|
| 26 |
+
// Choose number of tasks based on grain size and number of threads.
|
| 27 |
+
int64_t chunk_size = divup((end - begin), get_num_threads());
|
| 28 |
+
// Make sure each task is at least grain_size size.
|
| 29 |
+
chunk_size = std::max(grain_size, chunk_size);
|
| 30 |
+
|
| 31 |
+
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
| 32 |
+
std::exception_ptr eptr;
|
| 33 |
+
tbb::parallel_for(
|
| 34 |
+
tbb::blocked_range<int64_t>(begin, end, chunk_size),
|
| 35 |
+
[&eptr, &err_flag, f](const tbb::blocked_range<int64_t>& r) {
|
| 36 |
+
try {
|
| 37 |
+
internal::ThreadIdGuard tid_guard(
|
| 38 |
+
tbb::this_task_arena::current_thread_index());
|
| 39 |
+
f(r.begin(), r.end());
|
| 40 |
+
} catch (...) {
|
| 41 |
+
if (!err_flag.test_and_set()) {
|
| 42 |
+
eptr = std::current_exception();
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
},
|
| 46 |
+
tbb::static_partitioner{});
|
| 47 |
+
if (eptr) {
|
| 48 |
+
std::rethrow_exception(eptr);
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
} // namespace at::internal
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/SafePyObject.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
|
| 6 |
+
namespace at::impl {
|
| 7 |
+
|
| 8 |
+
enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
|
| 9 |
+
|
| 10 |
+
struct TORCH_API PythonTorchFunctionTLS {
|
| 11 |
+
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
|
| 12 |
+
static TorchFunctionDisabledState get_disabled_state();
|
| 13 |
+
|
| 14 |
+
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
|
| 15 |
+
static const std::shared_ptr<SafePyObject> pop_stack();
|
| 16 |
+
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
|
| 17 |
+
static int64_t stack_len();
|
| 18 |
+
|
| 19 |
+
static const PythonTorchFunctionTLS& get_state();
|
| 20 |
+
static void set_state(const PythonTorchFunctionTLS& state);
|
| 21 |
+
|
| 22 |
+
private:
|
| 23 |
+
// The mode TLS is split into
|
| 24 |
+
// - disabled_state, which says which part of torch function are disabled
|
| 25 |
+
// - stack_, which is a vector of modes representing the stack of user
|
| 26 |
+
// defined modes
|
| 27 |
+
TorchFunctionDisabledState disabled_state_ =
|
| 28 |
+
TorchFunctionDisabledState::ENABLED;
|
| 29 |
+
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
TORCH_API bool torch_function_mode_enabled();
|
| 33 |
+
|
| 34 |
+
} // namespace at::impl
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
// A simple thread local enumeration, used to link forward and backward pass
|
| 7 |
+
// ops and is used by autograd and observers framework
|
| 8 |
+
namespace at::sequence_number {
|
| 9 |
+
|
| 10 |
+
TORCH_API uint64_t peek();
|
| 11 |
+
TORCH_API uint64_t get_and_increment();
|
| 12 |
+
|
| 13 |
+
} // namespace at::sequence_number
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SparseTensorImpl.h
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/TensorImpl.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 9 |
+
#include <ATen/Functions.h>
|
| 10 |
+
#else
|
| 11 |
+
#include <ATen/ops/empty.h>
|
| 12 |
+
#include <ATen/ops/resize.h>
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
namespace at {
|
| 16 |
+
struct TORCH_API SparseTensorImpl : public TensorImpl {
|
| 17 |
+
// Stored in COO format, indices + values.
|
| 18 |
+
|
| 19 |
+
// INVARIANTS:
|
| 20 |
+
// sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
|
| 21 |
+
// dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
|
| 22 |
+
// _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
|
| 23 |
+
// _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
|
| 24 |
+
// shape[sparse_dim:])
|
| 25 |
+
|
| 26 |
+
int64_t sparse_dim_ = 0; // number of sparse dimensions
|
| 27 |
+
int64_t dense_dim_ = 0; // number of dense dimensions
|
| 28 |
+
|
| 29 |
+
Tensor indices_; // always a LongTensor
|
| 30 |
+
Tensor values_;
|
| 31 |
+
|
| 32 |
+
// A sparse tensor is 'coalesced' if every index occurs at most once in
|
| 33 |
+
// the indices tensor, and the indices are in sorted order. (This means
|
| 34 |
+
// that it is very easy to convert a coalesced tensor to CSR format: you
|
| 35 |
+
// need only compute CSR format indices.)
|
| 36 |
+
//
|
| 37 |
+
// Most math operations can only be performed on coalesced sparse tensors,
|
| 38 |
+
// because many algorithms proceed by merging two sorted lists (of indices).
|
| 39 |
+
bool coalesced_ = false;
|
| 40 |
+
|
| 41 |
+
// compute_numel with integer multiplication overflow check, see gh-57542
|
| 42 |
+
void refresh_numel() {
|
| 43 |
+
TensorImpl::safe_refresh_numel();
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
public:
|
| 47 |
+
// Public for now...
|
| 48 |
+
explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
|
| 49 |
+
|
| 50 |
+
void release_resources() override;
|
| 51 |
+
|
| 52 |
+
int64_t nnz() const {
|
| 53 |
+
return values_.size(0);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
c10::SymInt sym_nnz() const {
|
| 57 |
+
return values_.sym_size(0);
|
| 58 |
+
}
|
| 59 |
+
int64_t sparse_dim() const {
|
| 60 |
+
return sparse_dim_;
|
| 61 |
+
}
|
| 62 |
+
int64_t dense_dim() const {
|
| 63 |
+
return dense_dim_;
|
| 64 |
+
}
|
| 65 |
+
bool coalesced() const {
|
| 66 |
+
return coalesced_;
|
| 67 |
+
}
|
| 68 |
+
Tensor indices() const {
|
| 69 |
+
return indices_;
|
| 70 |
+
}
|
| 71 |
+
Tensor values() const {
|
| 72 |
+
return values_;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 76 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 77 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 78 |
+
|
| 79 |
+
#ifdef DEBUG
|
| 80 |
+
bool has_storage() const override;
|
| 81 |
+
#endif
|
| 82 |
+
|
| 83 |
+
// WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
|
| 84 |
+
// with respect to indices and values
|
| 85 |
+
void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
|
| 86 |
+
TORCH_CHECK(
|
| 87 |
+
allow_tensor_metadata_change(),
|
| 88 |
+
"raw_resize_ ",
|
| 89 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 90 |
+
TORCH_CHECK(
|
| 91 |
+
!has_symbolic_sizes_strides_,
|
| 92 |
+
"raw_resize_ called on tensor with symbolic shape")
|
| 93 |
+
set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
|
| 94 |
+
sparse_dim_ = sparse_dim;
|
| 95 |
+
dense_dim_ = dense_dim;
|
| 96 |
+
refresh_numel();
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
// NOTE: This function preserves invariants of sparse_dim/dense_dim with
|
| 100 |
+
// respect to indices and values.
|
| 101 |
+
//
|
| 102 |
+
// NOTE: This function supports the following cases:
|
| 103 |
+
// 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
|
| 104 |
+
// the size of any of the dense dimensions.
|
| 105 |
+
// 2. When we keep the number of sparse dimensions unchanged, and NOT
|
| 106 |
+
// shrinking the size of any of the sparse dimensions.
|
| 107 |
+
// 3. When the sparse tensor has zero nnz, in which case we are free to change
|
| 108 |
+
// the shapes of both its sparse and dense dimensions.
|
| 109 |
+
//
|
| 110 |
+
// This function DOESN'T support (and will throw an error) the following
|
| 111 |
+
// cases:
|
| 112 |
+
// 1. When we attempt to change the number of sparse dimensions on a non-empty
|
| 113 |
+
// sparse tensor (such an operation will invalidate the indices stored).
|
| 114 |
+
// 2. When we attempt to change the number of dense dimensions on a non-empty
|
| 115 |
+
// sparse tensor (such an operation will behave differently from an equivalent
|
| 116 |
+
// dense tensor's resize method, and for API consistency we don't support it).
|
| 117 |
+
// 3. When we attempt to shrink the size of any of the dense dimensions on a
|
| 118 |
+
// non-empty sparse tensor (such an operation will behave differently from an
|
| 119 |
+
// equivalent dense tensor's resize method, and for API consistency we don't
|
| 120 |
+
// support it).
|
| 121 |
+
// 4. When we attempt to shrink the size of any of the sparse dimensions on a
|
| 122 |
+
// non-empty sparse tensor (this could make some of the stored indices
|
| 123 |
+
// out-of-bound and thus unsafe).
|
| 124 |
+
template <typename T>
|
| 125 |
+
void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
|
| 126 |
+
TORCH_CHECK(
|
| 127 |
+
allow_tensor_metadata_change(),
|
| 128 |
+
"resize_ ",
|
| 129 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 130 |
+
TORCH_CHECK(
|
| 131 |
+
!has_symbolic_sizes_strides_,
|
| 132 |
+
"resize_ called on tensor with symbolic shape")
|
| 133 |
+
TORCH_CHECK(
|
| 134 |
+
sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
|
| 135 |
+
"number of dimensions must be sparse_dim (",
|
| 136 |
+
sparse_dim,
|
| 137 |
+
") + dense_dim (",
|
| 138 |
+
dense_dim,
|
| 139 |
+
"), but got ",
|
| 140 |
+
size.size());
|
| 141 |
+
if (nnz() > 0) {
|
| 142 |
+
auto alt_options_msg =
|
| 143 |
+
"You could try the following options:\n\
|
| 144 |
+
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
|
| 145 |
+
2. If you need to resize this tensor, you have the following options:\n\
|
| 146 |
+
1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
|
| 147 |
+
2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";
|
| 148 |
+
|
| 149 |
+
TORCH_CHECK(
|
| 150 |
+
sparse_dim == sparse_dim_,
|
| 151 |
+
"changing the number of sparse dimensions (from ",
|
| 152 |
+
sparse_dim_,
|
| 153 |
+
" to ",
|
| 154 |
+
sparse_dim,
|
| 155 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 156 |
+
alt_options_msg);
|
| 157 |
+
|
| 158 |
+
TORCH_CHECK(
|
| 159 |
+
dense_dim == dense_dim_,
|
| 160 |
+
"changing the number of dense dimensions (from ",
|
| 161 |
+
dense_dim_,
|
| 162 |
+
" to ",
|
| 163 |
+
dense_dim,
|
| 164 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 165 |
+
alt_options_msg);
|
| 166 |
+
|
| 167 |
+
bool shrinking_sparse_dims = false;
|
| 168 |
+
bool shrinking_dense_dim = false;
|
| 169 |
+
auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
|
| 170 |
+
auto sparse_size_new = size.slice(0, sparse_dim);
|
| 171 |
+
for (const auto i : c10::irange(sparse_dim)) {
|
| 172 |
+
if (sparse_size_new[i] < sparse_size_original[i]) {
|
| 173 |
+
shrinking_sparse_dims = true;
|
| 174 |
+
break;
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
|
| 178 |
+
auto dense_size_new = size.slice(sparse_dim);
|
| 179 |
+
for (const auto i : c10::irange(dense_dim)) {
|
| 180 |
+
if (dense_size_new[i] < dense_size_original[i]) {
|
| 181 |
+
shrinking_dense_dim = true;
|
| 182 |
+
break;
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
TORCH_CHECK(
|
| 187 |
+
!shrinking_sparse_dims,
|
| 188 |
+
"shrinking the size of sparse dimensions (from ",
|
| 189 |
+
sparse_size_original,
|
| 190 |
+
" to ",
|
| 191 |
+
sparse_size_new,
|
| 192 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 193 |
+
alt_options_msg);
|
| 194 |
+
|
| 195 |
+
TORCH_CHECK(
|
| 196 |
+
!shrinking_dense_dim,
|
| 197 |
+
"shrinking the size of dense dimensions (from ",
|
| 198 |
+
dense_size_original,
|
| 199 |
+
" to ",
|
| 200 |
+
dense_size_new,
|
| 201 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 202 |
+
alt_options_msg);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
auto sizes_and_strides = generic_sizes<T>();
|
| 206 |
+
const bool size_equals_sizes = std::equal(
|
| 207 |
+
size.begin(),
|
| 208 |
+
size.end(),
|
| 209 |
+
sizes_and_strides.begin(),
|
| 210 |
+
sizes_and_strides.end());
|
| 211 |
+
if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
|
| 212 |
+
(dense_dim != dense_dim_)) {
|
| 213 |
+
auto nnz = at::symint::sizes<T>(values())[0];
|
| 214 |
+
std::vector<T> values_size = {nnz};
|
| 215 |
+
auto dense_size = size.slice(sparse_dim);
|
| 216 |
+
values_size.insert(
|
| 217 |
+
values_size.end(), dense_size.begin(), dense_size.end());
|
| 218 |
+
at::symint::resize_<T>(values_, values_size);
|
| 219 |
+
at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
if (!size_equals_sizes) {
|
| 223 |
+
set_sizes_and_strides(size, std::vector<T>(size.size()));
|
| 224 |
+
}
|
| 225 |
+
sparse_dim_ = sparse_dim;
|
| 226 |
+
dense_dim_ = dense_dim;
|
| 227 |
+
refresh_numel();
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
|
| 231 |
+
return _resize_(sparse_dim, dense_dim, size);
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
void resize_(
|
| 235 |
+
int64_t sparse_dim,
|
| 236 |
+
int64_t dense_dim,
|
| 237 |
+
ArrayRef<c10::SymInt> size) {
|
| 238 |
+
return _resize_(sparse_dim, dense_dim, size);
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
// NOTE: this function will resize the sparse tensor and also set `indices`
|
| 242 |
+
// and `values` to empty.
|
| 243 |
+
void resize_and_clear_(
|
| 244 |
+
int64_t sparse_dim,
|
| 245 |
+
int64_t dense_dim,
|
| 246 |
+
IntArrayRef size) {
|
| 247 |
+
TORCH_CHECK(
|
| 248 |
+
allow_tensor_metadata_change(),
|
| 249 |
+
"resize_and_clear_ ",
|
| 250 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 251 |
+
TORCH_CHECK(
|
| 252 |
+
!has_symbolic_sizes_strides_,
|
| 253 |
+
"resize_and_clear_ called on tensor with symbolic shape")
|
| 254 |
+
TORCH_CHECK(
|
| 255 |
+
sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
|
| 256 |
+
"number of dimensions must be sparse_dim (",
|
| 257 |
+
sparse_dim,
|
| 258 |
+
") + dense_dim (",
|
| 259 |
+
dense_dim,
|
| 260 |
+
"), but got ",
|
| 261 |
+
size.size());
|
| 262 |
+
|
| 263 |
+
set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
|
| 264 |
+
sparse_dim_ = sparse_dim;
|
| 265 |
+
dense_dim_ = dense_dim;
|
| 266 |
+
|
| 267 |
+
auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
|
| 268 |
+
std::vector<int64_t> values_size = {0};
|
| 269 |
+
auto dense_size = sizes().slice(sparse_dim);
|
| 270 |
+
values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
|
| 271 |
+
auto empty_values = at::empty(values_size, values().options());
|
| 272 |
+
set_indices_and_values_unsafe(empty_indices, empty_values);
|
| 273 |
+
refresh_numel();
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
void set_coalesced(bool coalesced) {
|
| 277 |
+
TORCH_CHECK(
|
| 278 |
+
allow_tensor_metadata_change(),
|
| 279 |
+
"set_coalesced ",
|
| 280 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 281 |
+
coalesced_ = coalesced;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// NOTE: this function is only used internally and not exposed to Python
|
| 285 |
+
// frontend
|
| 286 |
+
void set_nnz_and_narrow(int64_t new_nnz) {
|
| 287 |
+
TORCH_CHECK(
|
| 288 |
+
allow_tensor_metadata_change(),
|
| 289 |
+
"set_nnz_and_narrow ",
|
| 290 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 291 |
+
AT_ASSERT(new_nnz <= nnz());
|
| 292 |
+
indices_ = indices_.narrow(1, 0, new_nnz);
|
| 293 |
+
values_ = values_.narrow(0, 0, new_nnz);
|
| 294 |
+
if (new_nnz < 2) {
|
| 295 |
+
coalesced_ = true;
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
// Takes indices and values and directly puts them into the sparse tensor, no
|
| 300 |
+
// copy. NOTE: this function is unsafe because it doesn't check whether any
|
| 301 |
+
// indices are out of boundaries of `sizes`, so it should ONLY be used where
|
| 302 |
+
// we know that the indices are guaranteed to be within bounds. This used to
|
| 303 |
+
// be called THSTensor_(_move) NB: This used to be able to avoid a refcount
|
| 304 |
+
// bump, but I was too lazy to make it happen
|
| 305 |
+
void set_indices_and_values_unsafe(
|
| 306 |
+
const Tensor& indices,
|
| 307 |
+
const Tensor& values);
|
| 308 |
+
|
| 309 |
+
/**
|
| 310 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 311 |
+
*
|
| 312 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 313 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 314 |
+
*/
|
| 315 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 316 |
+
const c10::VariableVersion& version_counter,
|
| 317 |
+
bool allow_tensor_metadata_change) const override {
|
| 318 |
+
auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
|
| 319 |
+
copy_tensor_metadata(
|
| 320 |
+
/*src_sparse_impl=*/this,
|
| 321 |
+
/*dest_sparse_impl=*/impl.get(),
|
| 322 |
+
/*version_counter=*/version_counter,
|
| 323 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 324 |
+
impl->refresh_numel();
|
| 325 |
+
return impl;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
/**
|
| 329 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 330 |
+
*
|
| 331 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 332 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 333 |
+
*/
|
| 334 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 335 |
+
c10::VariableVersion&& version_counter,
|
| 336 |
+
bool allow_tensor_metadata_change) const override {
|
| 337 |
+
auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
|
| 338 |
+
copy_tensor_metadata(
|
| 339 |
+
/*src_sparse_impl=*/this,
|
| 340 |
+
/*dest_sparse_impl=*/impl.get(),
|
| 341 |
+
/*version_counter=*/std::move(version_counter),
|
| 342 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 343 |
+
impl->refresh_numel();
|
| 344 |
+
return impl;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/**
|
| 348 |
+
* Shallow-copies data from another TensorImpl into this TensorImpl.
|
| 349 |
+
*
|
| 350 |
+
* For why this function doesn't check this TensorImpl's
|
| 351 |
+
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
|
| 352 |
+
*/
|
| 353 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
| 354 |
+
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
| 355 |
+
auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
|
| 356 |
+
copy_tensor_metadata(
|
| 357 |
+
/*src_sparse_impl=*/sparse_impl,
|
| 358 |
+
/*dest_sparse_impl=*/this,
|
| 359 |
+
/*version_counter=*/version_counter(),
|
| 360 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
| 361 |
+
refresh_numel();
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
private:
|
| 365 |
+
explicit SparseTensorImpl(
|
| 366 |
+
at::DispatchKeySet,
|
| 367 |
+
const caffe2::TypeMeta,
|
| 368 |
+
at::Tensor indices,
|
| 369 |
+
at::Tensor values);
|
| 370 |
+
|
| 371 |
+
/**
|
| 372 |
+
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
|
| 373 |
+
* storage_offset) from one TensorImpl to another TensorImpl.
|
| 374 |
+
*
|
| 375 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
|
| 376 |
+
* [ TensorImpl Shallow-Copying ].
|
| 377 |
+
*/
|
| 378 |
+
static void copy_tensor_metadata(
|
| 379 |
+
const SparseTensorImpl* src_sparse_impl,
|
| 380 |
+
SparseTensorImpl* dest_sparse_impl,
|
| 381 |
+
c10::VariableVersion version_counter,
|
| 382 |
+
bool allow_tensor_metadata_change) {
|
| 383 |
+
TensorImpl::copy_tensor_metadata(
|
| 384 |
+
src_sparse_impl,
|
| 385 |
+
dest_sparse_impl,
|
| 386 |
+
std::move(version_counter),
|
| 387 |
+
allow_tensor_metadata_change);
|
| 388 |
+
|
| 389 |
+
// Sparse-specific fields
|
| 390 |
+
dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
|
| 391 |
+
dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
|
| 392 |
+
dest_sparse_impl->indices_ = src_sparse_impl->indices();
|
| 393 |
+
dest_sparse_impl->values_ = src_sparse_impl->values();
|
| 394 |
+
dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
const char* tensorimpl_type_name() const override;
|
| 398 |
+
};
|
| 399 |
+
|
| 400 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/StorageUtils.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Storage.h>
|
| 4 |
+
#include <c10/core/StorageImpl.h>
|
| 5 |
+
#include <c10/util/intrusive_ptr.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
class TensorBase;
|
| 10 |
+
|
| 11 |
+
// Here we define a series of utils to create/manipulate ATen backed
|
| 12 |
+
// c10 storage implementations.
|
| 13 |
+
|
| 14 |
+
/**
|
| 15 |
+
* Create a new shared memory storage impl managed by file descriptor
|
| 16 |
+
*
|
| 17 |
+
* @param size size in bytes
|
| 18 |
+
*/
|
| 19 |
+
C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(size_t size);
|
| 20 |
+
|
| 21 |
+
/**
|
| 22 |
+
* Copy src to dst
|
| 23 |
+
* Caller must guarantee the validness of the storage objects
|
| 24 |
+
* during the entire copy process, esp. when it's async.
|
| 25 |
+
*
|
| 26 |
+
* This can probably live in c10 namespace later if needed,
|
| 27 |
+
* but for now keep it in at to keep implementation simple.
|
| 28 |
+
*
|
| 29 |
+
* @param dst dst tensor
|
| 30 |
+
* @param src src tensor
|
| 31 |
+
* @param non_blocking (default false) whether this operation blocks caller
|
| 32 |
+
*/
|
| 33 |
+
C10_EXPORT void storage_copy(
|
| 34 |
+
c10::Storage& dst,
|
| 35 |
+
const c10::Storage& src,
|
| 36 |
+
bool non_blocking = false);
|
| 37 |
+
|
| 38 |
+
/**
|
| 39 |
+
* In place change the storage to shm based.
|
| 40 |
+
*
|
| 41 |
+
* This is only applicable to CPU tensors not already shared.
|
| 42 |
+
* Otherwise, it's a no op to mirror the THP tensor behavior:
|
| 43 |
+
* https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html
|
| 44 |
+
*
|
| 45 |
+
* @param t a tensor
|
| 46 |
+
*/
|
| 47 |
+
C10_EXPORT void share_memory_(TensorBase& t);
|
| 48 |
+
|
| 49 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIndexing.h
ADDED
|
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ExpandUtils.h>
|
| 4 |
+
#include <ATen/ScalarOps.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
#include <ATen/core/TensorBody.h>
|
| 7 |
+
#include <c10/core/SymInt.h>
|
| 8 |
+
#include <c10/util/Optional.h>
|
| 9 |
+
#include <c10/util/irange.h>
|
| 10 |
+
|
| 11 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 12 |
+
#include <ATen/Functions.h>
|
| 13 |
+
#include <ATen/NativeFunctions.h>
|
| 14 |
+
#else
|
| 15 |
+
#include <ATen/ops/alias.h>
|
| 16 |
+
#include <ATen/ops/empty.h>
|
| 17 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 18 |
+
#include <ATen/ops/zeros.h>
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
#include <ATen/core/List.h>
|
| 22 |
+
|
| 23 |
+
#include <utility>
|
| 24 |
+
|
| 25 |
+
namespace at::indexing {
|
| 26 |
+
|
| 27 |
+
constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
|
| 28 |
+
constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
|
| 29 |
+
|
| 30 |
+
enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
|
| 31 |
+
|
| 32 |
+
constexpr c10::nullopt_t None = c10::nullopt;
|
| 33 |
+
|
| 34 |
+
struct TORCH_API EllipsisIndexType final {
|
| 35 |
+
EllipsisIndexType() = default;
|
| 36 |
+
};
|
| 37 |
+
TORCH_API extern const EllipsisIndexType Ellipsis;
|
| 38 |
+
|
| 39 |
+
struct TORCH_API Slice final {
|
| 40 |
+
public:
|
| 41 |
+
Slice(
|
| 42 |
+
c10::optional<c10::SymInt> start_index = c10::nullopt,
|
| 43 |
+
c10::optional<c10::SymInt> stop_index = c10::nullopt,
|
| 44 |
+
c10::optional<c10::SymInt> step_index = c10::nullopt) {
|
| 45 |
+
if (!step_index.has_value()) {
|
| 46 |
+
step_ = c10::SymInt(1);
|
| 47 |
+
} else {
|
| 48 |
+
step_ = std::move(step_index).value();
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
|
| 52 |
+
|
| 53 |
+
if (!start_index.has_value()) {
|
| 54 |
+
start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
|
| 55 |
+
} else {
|
| 56 |
+
start_ = std::move(start_index).value();
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
if (!stop_index.has_value()) {
|
| 60 |
+
stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
|
| 61 |
+
} else {
|
| 62 |
+
stop_ = std::move(stop_index).value();
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
inline c10::SymInt start() const {
|
| 67 |
+
return start_;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
inline c10::SymInt stop() const {
|
| 71 |
+
return stop_;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
inline c10::SymInt step() const {
|
| 75 |
+
return step_;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
private:
|
| 79 |
+
c10::SymInt start_;
|
| 80 |
+
c10::SymInt stop_;
|
| 81 |
+
c10::SymInt step_;
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
|
| 85 |
+
|
| 86 |
+
// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
|
| 87 |
+
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
|
| 88 |
+
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
|
| 89 |
+
// indexing operations can be performed using the supplied indices.
|
| 90 |
+
//
|
| 91 |
+
// There is one-to-one correspondence between Python and C++ tensor index types:
|
| 92 |
+
// Python | C++
|
| 93 |
+
// -----------------------------------------------------
|
| 94 |
+
// `None` | `at::indexing::None`
|
| 95 |
+
// `Ellipsis` | `at::indexing::Ellipsis`
|
| 96 |
+
// `...` | `"..."`
|
| 97 |
+
// `123` | `123`
|
| 98 |
+
// `True` / `False` | `true` / `false`
|
| 99 |
+
// `:` | `Slice()` / `Slice(None, None)`
|
| 100 |
+
// `::` | `Slice()` / `Slice(None, None, None)`
|
| 101 |
+
// `1:` | `Slice(1, None)`
|
| 102 |
+
// `1::` | `Slice(1, None, None)`
|
| 103 |
+
// `:3` | `Slice(None, 3)`
|
| 104 |
+
// `:3:` | `Slice(None, 3, None)`
|
| 105 |
+
// `::2` | `Slice(None, None, 2)`
|
| 106 |
+
// `1:3` | `Slice(1, 3)`
|
| 107 |
+
// `1::2` | `Slice(1, None, 2)`
|
| 108 |
+
// `:3:2` | `Slice(None, 3, 2)`
|
| 109 |
+
// `1:3:2` | `Slice(1, 3, 2)`
|
| 110 |
+
// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
|
| 111 |
+
struct TORCH_API TensorIndex final {
|
| 112 |
+
// Case 1: `at::indexing::None`
|
| 113 |
+
TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
|
| 114 |
+
|
| 115 |
+
// Case 2: "..." / `at::indexing::Ellipsis`
|
| 116 |
+
TensorIndex(at::indexing::EllipsisIndexType)
|
| 117 |
+
: type_(TensorIndexType::Ellipsis) {}
|
| 118 |
+
TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
|
| 119 |
+
TORCH_CHECK_VALUE(
|
| 120 |
+
strcmp(str, "...") == 0,
|
| 121 |
+
"Expected \"...\" to represent an ellipsis index, but got \"",
|
| 122 |
+
str,
|
| 123 |
+
"\"");
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Case 3: (Sym) Integer value
|
| 127 |
+
TensorIndex(SymInt integer)
|
| 128 |
+
: integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
|
| 129 |
+
TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
|
| 130 |
+
TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
|
| 131 |
+
|
| 132 |
+
// Case 4: Boolean value
|
| 133 |
+
template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
|
| 134 |
+
TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
|
| 135 |
+
|
| 136 |
+
// Case 5: Slice represented in `at::indexing::Slice` form
|
| 137 |
+
TensorIndex(Slice slice)
|
| 138 |
+
: slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
|
| 139 |
+
|
| 140 |
+
// Case 6: Tensor value
|
| 141 |
+
TensorIndex(Tensor tensor)
|
| 142 |
+
: tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
|
| 143 |
+
|
| 144 |
+
inline bool is_none() const {
|
| 145 |
+
return type_ == TensorIndexType::None;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
inline bool is_ellipsis() const {
|
| 149 |
+
return type_ == TensorIndexType::Ellipsis;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
inline bool is_integer() const {
|
| 153 |
+
return type_ == TensorIndexType::SymInt;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
inline SymInt integer() const {
|
| 157 |
+
return integer_;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
inline bool is_boolean() const {
|
| 161 |
+
return type_ == TensorIndexType::Boolean;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
inline bool boolean() const {
|
| 165 |
+
return boolean_;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
inline bool is_slice() const {
|
| 169 |
+
return type_ == TensorIndexType::Slice;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
inline const Slice& slice() const {
|
| 173 |
+
return slice_;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
inline bool is_tensor() const {
|
| 177 |
+
return type_ == TensorIndexType::Tensor;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
inline const Tensor& tensor() const {
|
| 181 |
+
return tensor_;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
private:
|
| 185 |
+
SymInt integer_ = 0;
|
| 186 |
+
bool boolean_ = false;
|
| 187 |
+
Slice slice_;
|
| 188 |
+
Tensor tensor_;
|
| 189 |
+
TensorIndexType type_;
|
| 190 |
+
};
|
| 191 |
+
|
| 192 |
+
TORCH_API std::ostream& operator<<(
|
| 193 |
+
std::ostream& stream,
|
| 194 |
+
const TensorIndex& tensor_index);
|
| 195 |
+
TORCH_API std::ostream& operator<<(
|
| 196 |
+
std::ostream& stream,
|
| 197 |
+
const std::vector<TensorIndex>& tensor_indices);
|
| 198 |
+
|
| 199 |
+
namespace impl {
|
| 200 |
+
static inline Tensor applySlice(
|
| 201 |
+
const Tensor& self,
|
| 202 |
+
int64_t dim,
|
| 203 |
+
c10::SymInt start,
|
| 204 |
+
c10::SymInt stop,
|
| 205 |
+
c10::SymInt step,
|
| 206 |
+
bool disable_slice_optimization,
|
| 207 |
+
const at::Device& self_device,
|
| 208 |
+
const c10::optional<SymIntArrayRef>& self_sizes) {
|
| 209 |
+
// TODO: implement negative step
|
| 210 |
+
TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
|
| 211 |
+
|
| 212 |
+
// See NOTE [nested tensor size for indexing]
|
| 213 |
+
if (self_sizes.has_value()) {
|
| 214 |
+
// Skip this optimization if we are tracing, as the trace may be polymorphic
|
| 215 |
+
// over the shape of the `self` tensor, and we still want to record
|
| 216 |
+
// the slice.
|
| 217 |
+
SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
|
| 218 |
+
? (*self_sizes)[dim]
|
| 219 |
+
: self.sym_size(dim);
|
| 220 |
+
if (!disable_slice_optimization &&
|
| 221 |
+
TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) && length == stop &&
|
| 222 |
+
step == 1) {
|
| 223 |
+
return self;
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
return self.slice_symint(
|
| 227 |
+
dim, std::move(start), std::move(stop), std::move(step));
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
static inline Tensor applySelect(
|
| 231 |
+
const Tensor& self,
|
| 232 |
+
int64_t dim,
|
| 233 |
+
SymInt index,
|
| 234 |
+
int64_t real_dim,
|
| 235 |
+
const at::Device& /*self_device*/,
|
| 236 |
+
const c10::optional<SymIntArrayRef>& self_sizes) {
|
| 237 |
+
// See NOTE [nested tensor size for indexing]
|
| 238 |
+
if (self_sizes.has_value()) {
|
| 239 |
+
auto maybe_index = index.maybe_as_int();
|
| 240 |
+
if (maybe_index.has_value()) {
|
| 241 |
+
TORCH_CHECK_INDEX(
|
| 242 |
+
!(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
|
| 243 |
+
"invalid index of a 0-dim tensor. ",
|
| 244 |
+
"Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
auto size = (*self_sizes)[dim];
|
| 248 |
+
// Note: `size >= -index` is not equivalent to `size > -1 - index` if index
|
| 249 |
+
// is INT64_MIN For std::numeric_limits<int64_t>::min() result of unary
|
| 250 |
+
// minus is undefined by the standard but in practice is equal to self. On
|
| 251 |
+
// the other hand, indexing wraping is valid for all negative int64_t
|
| 252 |
+
// values, as x[INT64_MIN] is the same as x[INT64_MAX]
|
| 253 |
+
TORCH_CHECK_INDEX(
|
| 254 |
+
size > -1 - index && size > index,
|
| 255 |
+
"index ",
|
| 256 |
+
index,
|
| 257 |
+
" is out of bounds for dimension ",
|
| 258 |
+
real_dim,
|
| 259 |
+
" with size ",
|
| 260 |
+
size);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
// if the index is negative, do not normalize it because that would fix the
|
| 264 |
+
// index on the current tensor size in the tracer. aten::select also works on
|
| 265 |
+
// negative indices
|
| 266 |
+
return self.select_symint(dim, std::move(index));
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
static inline Tensor boolToIndexingTensorCPUOrCUDA(
|
| 270 |
+
const Tensor& self,
|
| 271 |
+
bool value) {
|
| 272 |
+
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
|
| 273 |
+
// false as empty.
|
| 274 |
+
if (value) {
|
| 275 |
+
return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
|
| 276 |
+
} else {
|
| 277 |
+
return at::empty({0}, self.options().dtype(kLong));
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
static inline Tensor boolToIndexingTensorNonNativeDeviceType(
|
| 282 |
+
const Tensor& self,
|
| 283 |
+
bool value) {
|
| 284 |
+
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
|
| 285 |
+
// false as empty.
|
| 286 |
+
if (value) {
|
| 287 |
+
return at::zeros({1}, self.options().dtype(kLong));
|
| 288 |
+
} else {
|
| 289 |
+
return at::empty({0}, self.options().dtype(kLong));
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
static inline Tensor boolToIndexingTensor(
|
| 294 |
+
const Tensor& self,
|
| 295 |
+
bool value,
|
| 296 |
+
const at::Device& self_device) {
|
| 297 |
+
if (self_device == at::kCPU || self_device == at::kCUDA) {
|
| 298 |
+
return boolToIndexingTensorCPUOrCUDA(self, value);
|
| 299 |
+
} else {
|
| 300 |
+
return boolToIndexingTensorNonNativeDeviceType(self, value);
|
| 301 |
+
}
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
static inline Tensor scalarToTensorNonNativeDeviceType(
|
| 305 |
+
const Scalar& v,
|
| 306 |
+
const TensorOptions& options) {
|
| 307 |
+
return at::scalar_tensor(v, options);
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
static inline void recordTensorIndex(
|
| 311 |
+
const Tensor& tensor,
|
| 312 |
+
std::vector<Tensor>& outIndices,
|
| 313 |
+
int64_t* dim_ptr) {
|
| 314 |
+
// TODO: check scalarType
|
| 315 |
+
outIndices.resize(*dim_ptr + 1);
|
| 316 |
+
outIndices[*dim_ptr] = tensor;
|
| 317 |
+
(*dim_ptr)++;
|
| 318 |
+
};
|
| 319 |
+
|
| 320 |
+
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
|
| 321 |
+
const Tensor& /*self*/,
|
| 322 |
+
std::vector<Tensor>&& indices) {
|
| 323 |
+
c10::List<c10::optional<Tensor>> converted_inds;
|
| 324 |
+
converted_inds.reserve(indices.size());
|
| 325 |
+
for (auto&& i : std::move(indices)) {
|
| 326 |
+
converted_inds.push_back(std::move(i));
|
| 327 |
+
}
|
| 328 |
+
return converted_inds;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
|
| 332 |
+
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
|
| 333 |
+
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
|
| 334 |
+
// indexing (i.e. it's called by `applySlicing` which is called by
|
| 335 |
+
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
|
| 336 |
+
// than one dimension). If we were to merge the Python/C++
|
| 337 |
+
// `count_specified_dimensions` function, on the Python side we would have to
|
| 338 |
+
// construct a `std::vector` container to be consumed by the C++
|
| 339 |
+
// `count_specified_dimensions` function, which adds 100s of nanoseconds
|
| 340 |
+
// overhead and is undesirable.
|
| 341 |
+
static inline int64_t count_specified_dimensions(
|
| 342 |
+
const ArrayRef<TensorIndex>& indices) {
|
| 343 |
+
// Count the number of indexed dimensions (everything but ellipsis and None)
|
| 344 |
+
int64_t count = 0;
|
| 345 |
+
for (auto& obj : indices) {
|
| 346 |
+
if (obj.is_tensor()) {
|
| 347 |
+
auto& tensor = obj.tensor();
|
| 348 |
+
if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
|
| 349 |
+
count += tensor.dim();
|
| 350 |
+
} else {
|
| 351 |
+
count++;
|
| 352 |
+
}
|
| 353 |
+
} else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
|
| 354 |
+
count++;
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
return count;
|
| 358 |
+
}
|
| 359 |
+
} // namespace impl
|
| 360 |
+
|
| 361 |
+
// NOTE: Many functions below are only for consumption from Python indexing
|
| 362 |
+
// implementation, they include:
|
| 363 |
+
//
|
| 364 |
+
// - `Tensor scalarToTensor(...)`
|
| 365 |
+
// - `IntArrayRef slicePrefix1sSize(...)`
|
| 366 |
+
// - `void copy_to(...)`
|
| 367 |
+
// - `Tensor handleDimInMultiDimIndexing(...)`
|
| 368 |
+
// - `Tensor dispatch_index(...)`
|
| 369 |
+
// - `Tensor dispatch_index_put_(...)`
|
| 370 |
+
// - `Tensor get_item(...)`
|
| 371 |
+
// - `void set_item(...)`
|
| 372 |
+
//
|
| 373 |
+
// The rest of the functions are in `at::indexing::impl` namespace, signifying
|
| 374 |
+
// that they shouldn't be used from Python indexing implementation.
|
| 375 |
+
static inline Tensor scalarToTensor(
|
| 376 |
+
const Scalar& v,
|
| 377 |
+
const TensorOptions& options,
|
| 378 |
+
const at::Device& self_device) {
|
| 379 |
+
if (self_device == at::kCPU && !v.isSymbolic()) {
|
| 380 |
+
return at::detail::scalar_tensor_static(
|
| 381 |
+
v, options.dtype_opt()->toScalarType(), self_device);
|
| 382 |
+
} else {
|
| 383 |
+
return impl::scalarToTensorNonNativeDeviceType(v, options);
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
// To match numpy semantics:
|
| 388 |
+
// As a special case for backwards compatibility,
|
| 389 |
+
// strip away unit dimensions from the left of 'src'
|
| 390 |
+
static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
|
| 391 |
+
size_t first_non1_src = sizes.size();
|
| 392 |
+
for (const auto i : c10::irange(sizes.size())) {
|
| 393 |
+
// Unbacked SymInt has different behavior, but this is sound because
|
| 394 |
+
// failing to slice will only ever cause an error, not divergent
|
| 395 |
+
// behavior
|
| 396 |
+
if (!sizes[i].has_hint() || sizes[i] != 1) {
|
| 397 |
+
first_non1_src = i;
|
| 398 |
+
break;
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
return sizes.slice(first_non1_src);
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
static inline void copy_to(const Tensor& dst, const Tensor& src) {
|
| 406 |
+
if (dst.sym_sizes().equals(src.sym_sizes())) {
|
| 407 |
+
// A shortcut to avoid generating hard-coded constant sizes during tracing.
|
| 408 |
+
// This is not a perfect solution: when src & dst have different shapes,
|
| 409 |
+
// constants will still appear. Users can workaround that case by
|
| 410 |
+
// dst[index..] = src.reshape(..)
|
| 411 |
+
dst.copy_(src);
|
| 412 |
+
return;
|
| 413 |
+
} else if (src.dim() == 0 && src.device().type() == at::kCPU) {
|
| 414 |
+
dst.fill_(src);
|
| 415 |
+
return;
|
| 416 |
+
}
|
| 417 |
+
auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
|
| 418 |
+
c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
|
| 419 |
+
dst.copy_(*b_src);
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
|
| 423 |
+
// indexing functions from Python ]
|
| 424 |
+
static inline Tensor handleDimInMultiDimIndexing(
|
| 425 |
+
const Tensor& prev_dim_result,
|
| 426 |
+
const Tensor& original_tensor,
|
| 427 |
+
const TensorIndex& index,
|
| 428 |
+
int64_t* dim_ptr,
|
| 429 |
+
int64_t* specified_dims_ptr,
|
| 430 |
+
int64_t real_dim,
|
| 431 |
+
std::vector<Tensor>& outIndices,
|
| 432 |
+
bool disable_slice_optimization,
|
| 433 |
+
const at::Device& original_tensor_device,
|
| 434 |
+
const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
|
| 435 |
+
if (index.is_integer()) {
|
| 436 |
+
return impl::applySelect(
|
| 437 |
+
prev_dim_result,
|
| 438 |
+
*dim_ptr,
|
| 439 |
+
index.integer(),
|
| 440 |
+
real_dim,
|
| 441 |
+
original_tensor_device,
|
| 442 |
+
prev_dim_result_sizes);
|
| 443 |
+
} else if (index.is_slice()) {
|
| 444 |
+
Tensor result = impl::applySlice(
|
| 445 |
+
prev_dim_result,
|
| 446 |
+
*dim_ptr,
|
| 447 |
+
index.slice().start(),
|
| 448 |
+
index.slice().stop(),
|
| 449 |
+
index.slice().step(),
|
| 450 |
+
/*disable_slice_optimization=*/disable_slice_optimization,
|
| 451 |
+
original_tensor_device,
|
| 452 |
+
prev_dim_result_sizes);
|
| 453 |
+
(*dim_ptr)++;
|
| 454 |
+
return result;
|
| 455 |
+
} else if (index.is_ellipsis()) {
|
| 456 |
+
(*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
|
| 457 |
+
return prev_dim_result;
|
| 458 |
+
} else if (index.is_none()) {
|
| 459 |
+
Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
|
| 460 |
+
(*dim_ptr)++;
|
| 461 |
+
return result;
|
| 462 |
+
} else if (index.is_boolean()) {
|
| 463 |
+
Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
|
| 464 |
+
impl::recordTensorIndex(
|
| 465 |
+
impl::boolToIndexingTensor(
|
| 466 |
+
result, index.boolean(), original_tensor_device),
|
| 467 |
+
outIndices,
|
| 468 |
+
dim_ptr);
|
| 469 |
+
return result;
|
| 470 |
+
} else if (index.is_tensor()) {
|
| 471 |
+
Tensor result = prev_dim_result;
|
| 472 |
+
const Tensor& tensor = index.tensor();
|
| 473 |
+
auto scalar_type = tensor.scalar_type();
|
| 474 |
+
if (tensor.dim() == 0 &&
|
| 475 |
+
at::isIntegralType(scalar_type, /*includeBool=*/true)) {
|
| 476 |
+
if (scalar_type != at::kByte && scalar_type != at::kBool) {
|
| 477 |
+
result = impl::applySelect(
|
| 478 |
+
result,
|
| 479 |
+
*dim_ptr,
|
| 480 |
+
tensor.item<int64_t>(),
|
| 481 |
+
real_dim,
|
| 482 |
+
original_tensor_device,
|
| 483 |
+
prev_dim_result_sizes);
|
| 484 |
+
} else {
|
| 485 |
+
result = result.unsqueeze(*dim_ptr);
|
| 486 |
+
if (scalar_type == at::kBool) {
|
| 487 |
+
impl::recordTensorIndex(
|
| 488 |
+
impl::boolToIndexingTensor(
|
| 489 |
+
result, tensor.item<bool>() != 0, original_tensor_device),
|
| 490 |
+
outIndices,
|
| 491 |
+
dim_ptr);
|
| 492 |
+
} else {
|
| 493 |
+
impl::recordTensorIndex(
|
| 494 |
+
impl::boolToIndexingTensor(
|
| 495 |
+
result, tensor.item<uint8_t>() != 0, original_tensor_device),
|
| 496 |
+
outIndices,
|
| 497 |
+
dim_ptr);
|
| 498 |
+
}
|
| 499 |
+
}
|
| 500 |
+
} else {
|
| 501 |
+
impl::recordTensorIndex(tensor, outIndices, dim_ptr);
|
| 502 |
+
}
|
| 503 |
+
return result;
|
| 504 |
+
} else {
|
| 505 |
+
TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
|
| 506 |
+
}
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
namespace impl {
|
| 510 |
+
// This mirrors `applySlicing` in
|
| 511 |
+
// torch/csrc/autograd/python_variable_indexing.cpp
|
| 512 |
+
static inline Tensor applySlicing(
|
| 513 |
+
const Tensor& self,
|
| 514 |
+
const ArrayRef<TensorIndex>& indices,
|
| 515 |
+
std::vector<Tensor>& outIndices,
|
| 516 |
+
bool disable_slice_optimization,
|
| 517 |
+
const at::Device& self_device,
|
| 518 |
+
const c10::optional<SymIntArrayRef>& self_sizes) {
|
| 519 |
+
int64_t dim = 0;
|
| 520 |
+
int64_t specified_dims = impl::count_specified_dimensions(indices);
|
| 521 |
+
|
| 522 |
+
// See NOTE [nested tensor size for indexing]
|
| 523 |
+
if (self_sizes.has_value()) {
|
| 524 |
+
TORCH_CHECK_INDEX(
|
| 525 |
+
specified_dims <= (int64_t)self_sizes->size(),
|
| 526 |
+
"too many indices for tensor of dimension ",
|
| 527 |
+
(int)self_sizes->size());
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
Tensor result = self;
|
| 531 |
+
for (const auto i : c10::irange(indices.size())) {
|
| 532 |
+
auto& obj = indices[i];
|
| 533 |
+
// See NOTE [nested tensor size for indexing]
|
| 534 |
+
c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
|
| 535 |
+
? c10::optional<SymIntArrayRef>(c10::nullopt)
|
| 536 |
+
: c10::optional<SymIntArrayRef>(result.sym_sizes());
|
| 537 |
+
result = handleDimInMultiDimIndexing(
|
| 538 |
+
/*prev_dim_result=*/result,
|
| 539 |
+
/*original_tensor=*/self,
|
| 540 |
+
/*index=*/obj,
|
| 541 |
+
/*dim_ptr=*/&dim,
|
| 542 |
+
/*specified_dims_ptr=*/&specified_dims,
|
| 543 |
+
/*real_dim=*/static_cast<int64_t>(i),
|
| 544 |
+
/*outIndices=*/outIndices,
|
| 545 |
+
/*disable_slice_optimization=*/disable_slice_optimization,
|
| 546 |
+
/*original_tensor_device=*/self_device,
|
| 547 |
+
/*prev_dim_result_sizes=*/result_sizes);
|
| 548 |
+
}
|
| 549 |
+
return result;
|
| 550 |
+
}
|
| 551 |
+
} // namespace impl
|
| 552 |
+
|
| 553 |
+
static inline Tensor dispatch_index(
|
| 554 |
+
const Tensor& self,
|
| 555 |
+
std::vector<Tensor>&& indices) {
|
| 556 |
+
return self.index(impl::typeConvertIndices(self, std::move(indices)));
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
// Performs an in-place advanced-indexing write (`self[indices] = value`):
// converts the collected index tensors, then dispatches index_put_.
static inline Tensor dispatch_index_put_(
    Tensor& self,
    std::vector<Tensor>&& indices,
    const Tensor& value) {
  auto converted = impl::typeConvertIndices(self, std::move(indices));
  return self.index_put_(converted, value);
}
|
| 566 |
+
|
| 567 |
+
// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
|
| 568 |
+
// functions from Python ]
|
| 569 |
+
//
|
| 570 |
+
// Question: When should we set `disable_slice_optimization` to `true` when
|
| 571 |
+
// calling C++ tensor indexing functions from Python indexing code?
|
| 572 |
+
//
|
| 573 |
+
// Answer: What "slice optimization" means: when we have a slicing expression
|
| 574 |
+
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
|
| 575 |
+
// would skip dispatching the actual slice call as an optimization. However,
|
| 576 |
+
// here are the cases where we DON'T want this optimization:
|
| 577 |
+
//
|
| 578 |
+
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
|
| 579 |
+
// Reason: we always return a shallow copy for expressions such as
|
| 580 |
+
// `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
|
| 581 |
+
// :]`, we return an alias of `tensor` by doing the following:
|
| 582 |
+
// ```
|
| 583 |
+
// Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
|
| 584 |
+
// disable_slice_optimization, self_device, self_sizes); if
|
| 585 |
+
// (tensorIndices.empty()) {
|
| 586 |
+
// if (sliced.is_same(self)) {
|
| 587 |
+
// // ensure we return a shallow copy for things like x[...]
|
| 588 |
+
// sliced = at::alias(sliced);
|
| 589 |
+
// }
|
| 590 |
+
// return sliced;
|
| 591 |
+
// }
|
| 592 |
+
// ```)
|
| 593 |
+
// 2. When we are doing JIT tracing.
|
| 594 |
+
// Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
|
| 595 |
+
// slice operation.
|
| 596 |
+
|
| 597 |
+
// This mirrors `THPVariable_getitem` in
|
| 598 |
+
// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
|
| 599 |
+
// `disable_slice_optimization` when calling C++ tensor indexing functions from
|
| 600 |
+
// Python ]
|
| 601 |
+
// Reads `self[indices]` — the C++ analogue of Python `__getitem__`.
// Fast paths handle a single simple index (int / slice / None / Ellipsis /
// bool) without building index tensors; otherwise applySlicing resolves the
// non-tensor indices and dispatch_index performs advanced indexing.
static inline Tensor get_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  // NOTE [nested tensor size for indexing]
  // nested tensor does not have a size (yet) so for now we represent its size
  // as null may need to be changed after we reach a better solution for nested
  // tensor size
  c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
      ? c10::optional<SymIntArrayRef>(c10::nullopt)
      : c10::optional<SymIntArrayRef>(self.sym_sizes());

  // handle simple types: integers, slices, none, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_integer()) {
      return impl::applySelect(
          self, 0, index.integer(), 0, self_device, self_sizes);
    } else if (index.is_slice()) {
      // 1-D slicing always dispatches the real slice call (optimization
      // disabled) so `x[:]` returns a shallow copy, not `x` itself; see
      // NOTE [ Setting `disable_slice_optimization` ... ].
      return impl::applySlice(
          self,
          0,
          index.slice().start(),
          index.slice().stop(),
          index.slice().step(),
          /*disable_slice_optimization=*/true,
          self_device,
          self_sizes);
    } else if (index.is_none()) {
      // `x[None]` inserts a new leading dimension.
      return self.unsqueeze(0);
    } else if (index.is_ellipsis()) {
      // `x[...]` returns a shallow copy of the whole tensor.
      return at::alias(self);
    } else if (index.is_boolean()) {
      // `x[True]` / `x[False]`: unsqueeze, then index with a boolean-derived
      // index tensor (empty for False).
      Tensor result = self.unsqueeze(0);
      return dispatch_index(
          result,
          std::vector<Tensor>{impl::boolToIndexingTensor(
              result, index.boolean(), self_device)});
    }
  }

  // General case: resolve all non-tensor indices first.
  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    if (sliced.is_same(self)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return sliced;
  }

  // indexing by tensors ("advanced" indexing)
  return dispatch_index(sliced, std::move(tensorIndices));
}
|
| 662 |
+
|
| 663 |
+
// This mirrors `THPVariable_setitem` in
|
| 664 |
+
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
|
| 665 |
+
// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
|
| 666 |
+
// tensor indexing functions from Python ]
|
| 667 |
+
// Writes `value` into `self[indices]` — the C++ analogue of Python
// `__setitem__` when the right-hand side is a Tensor. Fast paths mirror
// get_item; the general case resolves non-tensor indices with applySlicing
// and writes through copy_to or index_put_.
static inline void set_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  SymIntArrayRef self_sizes = self.sym_sizes();

  // handle simple types: integers, slices, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_boolean() && !index.boolean()) {
      // do nothing for false (technically we should check the size, but we
      // don't have real 0-sized shapes.
      return;
    } else if (index.is_ellipsis()) {
      // `x[...] = v`: copy into the whole tensor.
      copy_to(self, value);
      return;
    } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
      // `x[None] = v` / `x[True] = v`: copy into the unsqueezed view.
      copy_to(self.unsqueeze(0), value);
      return;
    } else if (index.is_integer()) {
      copy_to(
          impl::applySelect(
              self, 0, index.integer(), 0, self_device, self_sizes),
          value);
      return;
    } else if (index.is_slice()) {
      copy_to(
          impl::applySlice(
              self,
              0,
              index.slice().start(),
              index.slice().stop(),
              index.slice().step(),
              /*disable_slice_optimization=*/disable_slice_optimization,
              self_device,
              self_sizes),
          value);
      return;
    }
  }

  // General case: resolve all non-tensor indices first.
  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    // No tensor ("advanced") indices: plain copy into the sliced view.
    copy_to(sliced, value);
    return;
  }

  // Reshape `value` for index_put_; slicePrefix1sSize presumably drops
  // leading size-1 dimensions from the value's shape — TODO confirm against
  // its definition.
  SymIntArrayRef valueSizes = value.sym_sizes();
  SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
  Tensor valuesSliced;
  if (!valueSizes.equals(slicedValueSizes)) {
    valuesSliced = value.view_symint(slicedValueSizes);
  } else {
    valuesSliced = value;
  }
  dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
  return;
}
|
| 734 |
+
|
| 735 |
+
} // namespace at::indexing
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/SafePyObject.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <unordered_map>
|
| 6 |
+
|
| 7 |
+
namespace at::impl {
|
| 8 |
+
|
| 9 |
+
// Thread-local string -> Python-object table. All accessors are static and
// operate on the current thread's state; the whole table can be captured and
// restored via get_state()/set_state() (used when propagating TLS across
// thread boundaries).
struct TORCH_API ThreadLocalPythonObjects {
  // Associates `value` with `key` in the current thread's table.
  static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
  // Returns the object stored under `key`. Behavior for a missing key is not
  // visible here — presumably an error; check before calling with contains().
  static const std::shared_ptr<SafePyObject>& get(const std::string& key);
  // True if `key` is present in the current thread's table.
  static bool contains(const std::string& key);

  // Snapshot / restore of the entire per-thread table.
  static const ThreadLocalPythonObjects& get_state();
  static void set_state(ThreadLocalPythonObjects state);

 private:
  std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
};
|
| 20 |
+
|
| 21 |
+
} // namespace at::impl
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/InferenceMode.h>
|
| 4 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/ThreadLocalDebugInfo.h>
|
| 7 |
+
|
| 8 |
+
#include <ATen/FuncTorchTLS.h>
|
| 9 |
+
#include <ATen/PythonTorchFunctionTLS.h>
|
| 10 |
+
#include <ATen/SavedTensorHooks.h>
|
| 11 |
+
#include <ATen/ThreadLocalPythonObjects.h>
|
| 12 |
+
#include <ATen/record_function.h>
|
| 13 |
+
#include <c10/core/impl/PythonDispatcherTLS.h>
|
| 14 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
// Thread local state contains values that are preserved across
|
| 19 |
+
// thread boundaries (e.g. at::launch/JIT fork, autograd).
|
| 20 |
+
// Note at::parallel_for doesn't preserve TLS across thread boundaries.
|
| 21 |
+
// Thread local state contains values that are preserved across
// thread boundaries (e.g. at::launch/JIT fork, autograd).
// Note at::parallel_for doesn't preserve TLS across thread boundaries.
class TORCH_API ThreadLocalState {
 public:
  // Saves the thread local variables' values and
  // returns them as a ThreadLocalState
  ThreadLocalState();

  // set_grad_mode - force the value of the grad mode TLS in
  // the current state object. This is used for example in the
  // autograd engine.
  void set_grad_mode(bool enabled);

  // set_multithreading_enabled - force the value of the multithreading-enabled
  // TLS flag in the current state object. This is used for example in the
  // autograd engine.
  void set_multithreading_enabled(bool enabled);

  // Sets thread local variables in the current thread,
  // according to the thread boundary specified
  static void setThreadLocalState(const ThreadLocalState& state);

 private:
  // Snapshot of the local dispatch-key set (included/excluded keys).
  c10::impl::LocalDispatchKeySet dispatch_key_;

  // ThreadLocalDebugInfo does not change after being created
  // with DebugInfoGuard
  std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;

  // RecordFunction TLS
  RecordFunctionTLS rf_tls_;

  // TLS for out-of-tree functorch
  // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
  // pointer (spoiler alert: it's due to the indirection)
  // This needs to be a shared_ptr instead of a unique_ptr because
  // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
  // consider adding an explicit copy constructor for ThreadLocalState in the
  // future but I didn't want to add one just for this.
  std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;

  // TLS for AutogradModes
  AutogradState autograd_tls_;

  // TLS for enable_torch_dispatch_mode
  c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;

  // TLS for enable_python_dispatcher
  c10::impl::PyInterpreter* python_dispatcher_state_;

  // TLS for __torch_function__ (mode and disable_torch_function)
  at::impl::PythonTorchFunctionTLS python_torch_function_state_;

  // TLS for saved tensors default hooks
  at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;

  // Whether functionalization should re-apply views (see functionalization
  // pass); captured and restored like the other flags.
  bool functionalization_reapply_views_state_;

  // TLS for arbitrary python objects that is registered via hooks
  at::impl::ThreadLocalPythonObjects saved_objects_;

  friend class ThreadLocalStateGuard;
};
|
| 83 |
+
|
| 84 |
+
// Guard to set and reset the thread local state
|
| 85 |
+
// Guard to set and reset the thread local state
// RAII: on construction snapshots the current TLS and installs `state`;
// on destruction restores the snapshot.
class TORCH_API ThreadLocalStateGuard {
 public:
  explicit ThreadLocalStateGuard(const ThreadLocalState& state)
      : prev_state_(ThreadLocalState()) {
    // set the given state across the thread boundary
    ThreadLocalState::setThreadLocalState(state);
  }

  ~ThreadLocalStateGuard() {
    // restore previously set variables
    ThreadLocalState::setThreadLocalState(prev_state_);
  }

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const ThreadLocalState prev_state_;
};
|
| 102 |
+
|
| 103 |
+
template <typename T>
|
| 104 |
+
auto wrapPropagateTLSState(T callback) {
|
| 105 |
+
return [tls_state = ThreadLocalState(),
|
| 106 |
+
callback = std::move(callback)](auto&&... args) {
|
| 107 |
+
ThreadLocalStateGuard g(tls_state);
|
| 108 |
+
// Propagate value returned by callback().
|
| 109 |
+
return callback(std::forward<decltype(args)>(args)...);
|
| 110 |
+
};
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TypeDefault.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Dimname.h>
|
| 4 |
+
#include <c10/core/MemoryFormat.h>
|
| 5 |
+
#include <c10/core/QScheme.h>
|
| 6 |
+
#include <c10/core/Scalar.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/macros/Export.h>
|
| 9 |
+
#include <c10/util/ArrayRef.h>
|
| 10 |
+
#include <c10/util/intrusive_ptr.h>
|
| 11 |
+
|
| 12 |
+
namespace c10 {
// Forward declaration; the full definition lives in c10.
struct Storage;
}

namespace at {

class Tensor;
// Non-owning view over a contiguous sequence of Tensors.
using TensorList = ArrayRef<Tensor>;

class Context;
struct Generator;

struct Quantizer;
// This is temporary typedef to enable Quantizer in aten native function API
// we'll remove them when we are actually exposing Quantizer class
// to frontend
using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;

} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/Context.h>
|
| 2 |
+
|
| 3 |
+
namespace at {

/// Returns a detailed string describing the configuration of PyTorch.
TORCH_API std::string show_config();

/// Version string of the MKL library PyTorch was built against.
TORCH_API std::string get_mkl_version();

/// Version string of the MKL-DNN (oneDNN) library.
TORCH_API std::string get_mkldnn_version();

/// Version string of the OpenMP runtime.
TORCH_API std::string get_openmp_version();

/// Compiler flags PyTorch was built with.
TORCH_API std::string get_cxx_flags();

/// Detected CPU capability (e.g. AVX2/AVX512) selected at runtime.
TORCH_API std::string get_cpu_capability();

} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Atomic.cuh
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <c10/util/Half.h>
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
|
| 7 |
+
#include <ATen/NumericUtils.h>
|
| 8 |
+
|
| 9 |
+
#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// Generic CAS-loop functor for floating-point atomics; specialized below for
// types that lack (or may lack) hardware atomic support.
template <typename T>
struct AtomicFPOp;

// 16-bit Half fallback: atomically applies `func` by CAS-ing the aligned
// 32-bit word that contains `address`, touching only the high or low
// half-word (selected by `(size_t)address & 2`).
template <>
struct AtomicFPOp<at::Half> {
  template <typename func_t>
  inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
    // Round the address down to its containing 4-byte-aligned word.
    unsigned int * address_as_ui =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    at::Half hsum;
    do {
      assumed = old;
      // Extract the half-word this Half occupies.
      hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
      hsum = func(hsum, val);
      // Splice the updated half-word back into the 32-bit word.
      old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old); // retry if another thread changed the word
    // Re-extract from the final word so the returned value matches memory.
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    return hsum;
  }
};
|
| 37 |
+
|
| 38 |
+
// 16-bit BFloat16 fallback: same CAS-on-containing-32-bit-word scheme as the
// at::Half specialization above.
template <>
struct AtomicFPOp<at::BFloat16> {
  template <typename func_t>
  inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
    // Round the address down to its containing 4-byte-aligned word.
    unsigned int * address_as_ui =
      (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    at::BFloat16 bsum;
    do {
      assumed = old;
      // Extract the half-word this BFloat16 occupies.
      bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
      bsum = func(bsum, val);
      // Splice the updated half-word back into the 32-bit word.
      old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old); // retry if another thread changed the word
    bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    // NOTE(review): returns bsum.x (the uint16 bit payload) rather than bsum,
    // unlike the at::Half specialization which returns hsum; this relies on a
    // uint16 -> BFloat16 conversion — confirm intent against upstream.
    return bsum.x;
  }
};
|
| 59 |
+
|
| 60 |
+
// 64-bit double fallback: CAS loop over the raw 8-byte word. `func` receives
// (val, assumed-bits) and must return the new value as long-long bits.
template <>
struct AtomicFPOp<double> {
  template <typename func_t>
  inline __device__ double operator() (double * address, double val, const func_t& func) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull;
    unsigned long long int assumed;

    do {
      assumed = old;
      old = atomicCAS(address_as_ull, assumed, func(val, assumed));
      // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);

    return __longlong_as_double(old);
  }
};
|
| 77 |
+
|
| 78 |
+
// Generates Atomic<NAME>IntegerImpl<T, n> functors that implement an atomic
// read-modify-write of `func` for integer types of width n = 1, 2, 4, or 8
// bytes. The 1- and 2-byte variants CAS the containing 4-byte-aligned 32-bit
// word and splice only the target byte/half-word; the 4- and 8-byte variants
// CAS the value directly. (Comments cannot appear inside the macro body
// because of the line continuations.)
#define ATOMIC_INTEGER_IMPL(NAME)                                                                                     \
template <typename T, size_t n>                                                                                       \
struct Atomic##NAME##IntegerImpl;                                                                                     \
                                                                                                                      \
template<typename T>                                                                                                  \
struct Atomic##NAME##IntegerImpl<T, 1> {                                                                              \
  template <typename func_t>                                                                                          \
  inline __device__ void operator()(T *address, T val, const func_t& func) {                                          \
    size_t offset = (size_t)address & 3;                                                                              \
    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset);                                                \
    uint32_t old = *address_as_ui;                                                                                    \
    uint32_t shift = offset * 8;                                                                                      \
    uint32_t old_byte;                                                                                                \
    uint32_t newval;                                                                                                  \
    uint32_t assumed;                                                                                                 \
                                                                                                                      \
    do {                                                                                                              \
      assumed = old;                                                                                                  \
      old_byte = (old >> shift) & 0xff;                                                                               \
      newval = static_cast<uint8_t>(func(val, static_cast<T>(old_byte)));                                             \
      newval = (old & ~(0x000000ff << shift)) | (newval << shift);                                                    \
      old = atomicCAS(address_as_ui, assumed, newval);                                                                \
    } while (assumed != old);                                                                                         \
  }                                                                                                                   \
};                                                                                                                    \
                                                                                                                      \
template<typename T>                                                                                                  \
struct Atomic##NAME##IntegerImpl<T, 2> {                                                                              \
  template <typename func_t>                                                                                          \
  inline __device__ void operator()(T *address, T val, const func_t& func) {                                          \
    size_t offset = (size_t)address & 2;                                                                              \
    uint32_t * address_as_ui = (uint32_t *)((char *)address - offset);                                                \
    bool is_32_align = offset;                                                                                        \
    uint32_t old = *address_as_ui;                                                                                    \
    uint32_t old_bytes;                                                                                               \
    uint32_t newval;                                                                                                  \
    uint32_t assumed;                                                                                                 \
                                                                                                                      \
    do {                                                                                                              \
      assumed = old;                                                                                                  \
      old_bytes = is_32_align ? old >> 16 : old & 0xffff;                                                             \
      newval = static_cast<uint16_t>(func(val, static_cast<T>(old_bytes)));                                           \
      newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval;                           \
      old = atomicCAS(address_as_ui, assumed, newval);                                                                \
    } while (assumed != old);                                                                                         \
  }                                                                                                                   \
};                                                                                                                    \
                                                                                                                      \
template<typename T>                                                                                                  \
struct Atomic##NAME##IntegerImpl<T, 4> {                                                                              \
  template <typename func_t>                                                                                          \
  inline __device__ void operator()(T *address, T val, const func_t& func) {                                          \
    uint32_t * address_as_ui = (uint32_t *) (address);                                                                \
    uint32_t old = *address_as_ui;                                                                                    \
    uint32_t newval;                                                                                                  \
    uint32_t assumed;                                                                                                 \
                                                                                                                      \
    do {                                                                                                              \
      assumed = old;                                                                                                  \
      newval = static_cast<uint32_t>(func(val, static_cast<T>(old)));                                                 \
      old = atomicCAS(address_as_ui, assumed, newval);                                                                \
    } while (assumed != old);                                                                                         \
  }                                                                                                                   \
};                                                                                                                    \
                                                                                                                      \
template<typename T>                                                                                                  \
struct Atomic##NAME##IntegerImpl<T, 8> {                                                                              \
  template <typename func_t>                                                                                          \
  inline __device__ void operator()(T *address, T val, const func_t& func) {                                          \
    unsigned long long * address_as_ui = (unsigned long long *) (address);                                            \
    unsigned long long old = *address_as_ui;                                                                          \
    unsigned long long newval;                                                                                        \
    unsigned long long assumed;                                                                                       \
                                                                                                                      \
    do {                                                                                                              \
      assumed = old;                                                                                                  \
      newval = static_cast<uint64_t>(func(val, static_cast<T>(old)));                                                 \
      old = atomicCAS(address_as_ui, assumed, newval);                                                                \
    } while (assumed != old);                                                                                         \
  }                                                                                                                   \
};
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
// Generates `gpuAtomic<NAME>(DTYPE*, DTYPE)` in terms of the matching
// Atomic<NAME>IntegerImpl specialization; OP is an expression over `a`
// (incoming val) and `b` (current memory value).
# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE)                                  \
static inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) {    \
Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address,                    \
                                                  val,                        \
                                                  [](DTYPE a, DTYPE b) {      \
                                                    return OP;                \
                                                  });                         \
}                                                                             \

// Instantiate the CAS-loop machinery for Add; bool "add" is logical OR.
ATOMIC_INTEGER_IMPL(Add)
GPU_ATOMIC_INTEGER(Add, a || b, bool)
|
| 172 |
+
|
| 173 |
+
// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)

// Sub-word types go through the CAS-loop implementation.
static inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
                                                   val,
                                                   [](uint8_t a, uint8_t b) {
                                                     return a + b;
                                                   });
}

static inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
                                                 val,
                                                 [](int8_t a, int8_t b) {
                                                   return a + b;
                                                 });
}

static inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
                                                   val,
                                                   [](int16_t a, int16_t b) {
                                                     return a + b;
                                                   });
}

// 32-bit int has a native atomicAdd; note this overload returns the old
// value, unlike the void sub-word overloads above.
static inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
  return atomicAdd(address, val);
}

static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#if defined(USE_ROCM)
  // ROCm: use the compiler builtin (relaxed ordering).
  __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
  // CUDA only provides atomicAdd for unsigned long long; reinterpret, which
  // is sound for two's-complement addition as the widths match.
  static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
  atomicAdd(reinterpret_cast<unsigned long long int *>(address), static_cast<unsigned long long int>(val));
#endif
}
|
| 210 |
+
|
| 211 |
+
// Half add: hardware atomicAdd(__half*) exists from sm_70; older archs and
// ROCm fall back to the CAS loop.
static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half hsum, at::Half val) {
                                  return hsum + val;
                                });
#else
  return atomicAdd(reinterpret_cast<__half*>(address), val);
#endif
}

// BFloat16 add: hardware atomicAdd(__nv_bfloat16*) exists from sm_80; older
// archs and ROCm fall back to the CAS loop.
static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum + val;
                                    });
#else
  // Bit-reinterpret between at::BFloat16 and __nv_bfloat16 (same layout).
  __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
  return *reinterpret_cast<c10::BFloat16*>(&r);
#endif
}
|
| 233 |
+
|
| 234 |
+
// atomicAdd(double*) only exists in hardware from sm_60; provide the
// documented CAS-based emulation for older archs, and a host-side stub for
// old hcc ROCm toolchains.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
// from CUDA C Programmic Guide
static inline __device__ double atomicAdd(double* address, double val)
#if defined(__clang__) && defined(__CUDA__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wgcc-compat"
__attribute__((enable_if(true, "")))
#pragma GCC diagnostic pop
#endif
{

  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val + __longlong_as_double(assumed));
                              });
}
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))

/* Note [hip-clang differences to hcc]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
 * It exports the __HIP__ macro, we can hence differentiate between hcc and
 * hip-clang. In the below, hcc only received support for atomicAdd with double
 * typing after work week 18312. hip-clang had support from the first version.
 * In general, the code-visible differences between hip-clang and hcc will be
 * minimal.
 */

#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
// This needs to be defined for the host side pass
// (never executed; deliberately has no return value).
static inline __device__ double atomicAdd(double *address, double val) { }
#endif
#endif
|
| 267 |
+
|
| 268 |
+
// double/float have native (or emulated, above) atomicAdd.
static inline __device__ double gpuAtomicAdd(double *address, double val) {
  return atomicAdd(address, val);
}

static inline __device__ float gpuAtomicAdd(float *address, float val) {
  return atomicAdd(address, val);
}

// Complex add is done component-wise; the two component updates are each
// atomic but not atomic together.
template<typename T>
static inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
  gpuAtomicAdd(&address->real_, val.real_);
  gpuAtomicAdd(&address->imag_, val.imag_);
}
|
| 281 |
+
|
| 282 |
+
/* Note [gpuAtomicAdd vs atomicAdd]
|
| 283 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 284 |
+
* Some extensions such as torchvision call atomicAdd()
|
| 285 |
+
* directly and require non-library provided data type support. Only for these, we
|
| 286 |
+
* continue to provide atomicAdd overloads.
|
| 287 |
+
*/
|
| 288 |
+
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
|
| 289 |
+
return gpuAtomicAdd(address, val);
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
static inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
|
| 293 |
+
return gpuAtomicAdd(address, val);
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
// Non-returning atomicAdd overloads for integral types and bool, kept for
// extensions that call atomicAdd() directly (see Note [gpuAtomicAdd vs
// atomicAdd] above). Each forwards to gpuAtomicAdd and discards the
// previous value.
static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int16_t *address, int16_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
  gpuAtomicAdd(address, val);
}

static inline __device__ void atomicAdd(bool *address, bool val) {
  gpuAtomicAdd(address, val);
}
|
| 315 |
+
|
| 316 |
+
/* Note [explicitly non-returning atomics]
|
| 317 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 318 |
+
* AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
|
| 319 |
+
* Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
|
| 320 |
+
* This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
|
| 321 |
+
* therefore we need a new API 'gpuAtomicAddNoReturn'.
|
| 322 |
+
*/
|
| 323 |
+
// gpuAtomicAddNoReturn overloads (see Note [explicitly non-returning atomics]
// above): for every type except fp32-on-ROCm (handled below) they simply
// forward to gpuAtomicAdd and discard the returned previous value.
template<typename T>
static inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
static inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
|
| 334 |
+
|
| 335 |
+
/* Special case fp32 atomic. */
// On ROCm, call atomicAddNoRet so the optimized non-returning fp32 add is
// emitted (see Note [explicitly non-returning atomics] above); on other
// platforms fall back to gpuAtomicAdd and discard the result.
#if defined(USE_ROCM)
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
#else
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
#endif
|
| 341 |
+
|
| 342 |
+
// Atomic multiplication implementation.
|
| 343 |
+
|
| 344 |
+
ATOMIC_INTEGER_IMPL(Mul)
|
| 345 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
|
| 346 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
|
| 347 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
|
| 348 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
|
| 349 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
|
| 350 |
+
|
| 351 |
+
// Atomic multiply for at::Half, emulated through AtomicFPOp's update loop;
// the lambda computes the new value from the currently stored one.
inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return bsum * val;
                                });
}

// Atomic multiply for at::BFloat16, same AtomicFPOp-based emulation.
inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return bsum * val;
                                    });
}

// Atomic multiply for double: AtomicFPOp hands the lambda the currently
// stored value as a 64-bit integer bit pattern (`assumed`); the lambda
// returns the new value's bit pattern.
inline __device__ double gpuAtomicMul(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(val * __longlong_as_double(assumed));
                              });
}
|
| 371 |
+
|
| 372 |
+
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
|
| 373 |
+
// Atomic multiply for float, emulated with a compare-and-swap loop over the
// value's raw 32-bit pattern. Returns the value stored at *address before
// this thread's successful update.
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMul (float * address, float val) {
  unsigned int* addr_as_ui = (unsigned int*)address;
  unsigned int observed = *addr_as_ui;
  unsigned int expected;

  do {
    expected = observed;
    const unsigned int desired =
        __float_as_int(val * __int_as_float(expected));
    observed = atomicCAS(addr_as_ui, expected, desired);
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (expected != observed);

  return __int_as_float(observed);
}
|
| 389 |
+
|
| 390 |
+
// Atomic maximum implementation.
|
| 391 |
+
|
| 392 |
+
// NaN-aware maximum used by the atomic max implementations: a NaN in `b` is
// returned explicitly, and a NaN in `a` falls out of std::max's comparison
// (a < b is false for NaN), so NaN propagates instead of being dropped.
template <typename T>
__host__ __device__ T safe_max(T a, T b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  // HIP needs the explicit check on `a` as well.
  T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max<T>(a, b));
#else
  T max = at::_isnan(b) ? b : std::max<T>(a, b);
#endif

  return max;
}
|
| 404 |
+
|
| 405 |
+
ATOMIC_INTEGER_IMPL(Max)
|
| 406 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
|
| 407 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
|
| 408 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
|
| 409 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
|
| 410 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)
|
| 411 |
+
|
| 412 |
+
// Atomic max for at::Half, emulated through AtomicFPOp's update loop with a
// NaN-propagating comparison (safe_max).
inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_max(bsum, val);
                                });
}

// Atomic max for at::BFloat16, same AtomicFPOp-based emulation.
inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_max(bsum, val);
                                    });
}

// Atomic max for double: the lambda receives the current value as a 64-bit
// bit pattern (`assumed`) and returns the new value's bit pattern.
inline __device__ double gpuAtomicMax(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
                              });
}
|
| 432 |
+
|
| 433 |
+
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
|
| 434 |
+
// Atomic max for float, emulated with a compare-and-swap loop over the
// value's raw 32-bit pattern; safe_max keeps NaN propagation consistent with
// the other overloads. Returns the previously stored value.
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMax(float * address, float val) {
  unsigned int* addr_as_ui = (unsigned int*)address;
  unsigned int observed = *addr_as_ui;
  unsigned int expected;

  do {
    expected = observed;
    const unsigned int desired =
        __float_as_int(safe_max(val, __int_as_float(expected)));
    observed = atomicCAS(addr_as_ui, expected, desired);
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (expected != observed);

  return __int_as_float(observed);
}
|
| 449 |
+
|
| 450 |
+
// Atomic minimum implementation.
|
| 451 |
+
|
| 452 |
+
// NaN-aware minimum used by the atomic min implementations: a NaN in `b` is
// returned explicitly, and a NaN in `a` falls out of std::min's comparison
// (b < a is false for NaN), so NaN propagates instead of being dropped.
template <typename T>
__host__ __device__ T safe_min(T a, T b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  // HIP needs the explicit check on `a` as well.
  T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min<T>(a, b));
#else
  T min = at::_isnan(b) ? b : std::min<T>(a, b);
#endif

  return min;
}
|
| 464 |
+
|
| 465 |
+
ATOMIC_INTEGER_IMPL(Min)
|
| 466 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
|
| 467 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
|
| 468 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
|
| 469 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
|
| 470 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)
|
| 471 |
+
|
| 472 |
+
// Atomic min for at::Half, emulated through AtomicFPOp's update loop with a
// NaN-propagating comparison (safe_min).
inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
  return AtomicFPOp<at::Half>()(address, val,
                                [](at::Half bsum, at::Half val) {
                                  return safe_min(bsum, val);
                                });
}

// Atomic min for at::BFloat16, same AtomicFPOp-based emulation.
inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
  return AtomicFPOp<at::BFloat16>()(address, val,
                                    [](at::BFloat16 bsum, at::BFloat16 val) {
                                      return safe_min(bsum, val);
                                    });
}

// Atomic min for double: the lambda receives the current value as a 64-bit
// bit pattern (`assumed`) and returns the new value's bit pattern.
inline __device__ double gpuAtomicMin(double * address, double val) {
  return AtomicFPOp<double>()(address, val,
                              [](double val, unsigned long long int assumed) {
                                return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
                              });
}
|
| 492 |
+
|
| 493 |
+
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
|
| 494 |
+
// Atomic min for float, emulated with a compare-and-swap loop over the
// value's raw 32-bit pattern; safe_min keeps NaN propagation consistent with
// the other overloads. Returns the previously stored value.
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
inline __device__ float gpuAtomicMin(float * address, float val) {
  unsigned int* addr_as_ui = (unsigned int*)address;
  unsigned int observed = *addr_as_ui;
  unsigned int expected;

  do {
    expected = observed;
    const unsigned int desired =
        __float_as_int(safe_min(val, __int_as_float(expected)));
    observed = atomicCAS(addr_as_ui, expected, desired);
    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (expected != observed);

  return __int_as_float(observed);
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/ApplyGridUtils.cuh>
|
| 4 |
+
#include <ATen/cuda/detail/IndexUtils.cuh>
|
| 5 |
+
#include <ATen/core/TensorBase.h>
|
| 6 |
+
#include <ATen/ceil_div.h>
|
| 7 |
+
#include <ATen/cuda/Atomic.cuh>
|
| 8 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 9 |
+
#include <c10/macros/Macros.h>
|
| 10 |
+
#include <ATen/native/Copy.h>
|
| 11 |
+
|
| 12 |
+
#include <math.h>
|
| 13 |
+
|
| 14 |
+
//
|
| 15 |
+
// This file contains pointwise operation functions and kernels that
|
| 16 |
+
// work on both contiguous and non-contiguous tensor arguments of
|
| 17 |
+
// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without
|
| 18 |
+
// copying or temporary storage.
|
| 19 |
+
//
|
| 20 |
+
|
| 21 |
+
/*
|
| 22 |
+
NOTE [ CUDA_tensor_applyN helpers ]
|
| 23 |
+
|
| 24 |
+
The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
|
| 25 |
+
functions apply a pointwise operator to N tensor(s).
|
| 26 |
+
|
| 27 |
+
The calling convention is
|
| 28 |
+
|
| 29 |
+
1. The template arguments should be, sequentially,
|
| 30 |
+
- First N typename args specify the scalar types of each of the N tensors.
|
| 31 |
+
- (Optional) `int step` arg specifies the number of elements processed
|
| 32 |
+
together at the same time.
|
| 33 |
+
Default is 1.
|
| 34 |
+
- A usually omitted (i.e., inferred) typename arg specifies the type of the
|
| 35 |
+
function/functor applied on `N * step` values in each iteration of each
|
| 36 |
+
CUDA thread.
|
| 37 |
+
2. The arguments should be, sequentially,
|
| 38 |
+
- N tensors
|
| 39 |
+
- op: a function/functor that processes `N * step` values at the same time.
|
| 40 |
+
- If `step == 1`, it must have signature
|
| 41 |
+
`void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
|
| 42 |
+
`scalar*_t`s are the first N typename template args, and the inputs
|
| 43 |
+
are the `N` values from the `N` tensors retrieved at a common index.
|
| 44 |
+
- Otherwise, it must have signature
|
| 45 |
+
void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times
|
| 46 |
+
scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times
|
| 47 |
+
...,
|
| 48 |
+
scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times
|
| 49 |
+
Different from `step == 1` case, it processes `N * step` values taken
|
| 50 |
+
from `step` common indices. Moreover, the first input `n` represents the
|
| 51 |
+
number of valid indices (it will always have `0 < n <= step`). It will
|
| 52 |
+
almost always be `step`, but at the boundary we may not have full `step`
|
| 53 |
+
elements and `n` can be a lesser value.
|
| 54 |
+
|
| 55 |
+
E.g., if `step == 4` and `N == 2`, `op` could be
|
| 56 |
+
|
| 57 |
+
[](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
|
| 58 |
+
scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
|
| 59 |
+
// Only process u1, ..., un and v1, ..., vn.
|
| 60 |
+
// So if `n == 3`, `u4` and `v4` need not to be considered.
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
In both cases, the references can actually be const, but at least one of
|
| 64 |
+
them should be non-const in order to write the output.
|
| 65 |
+
- (Optional, but recommended) N TensorArgType args that specify for each
|
| 66 |
+
tensor whether `op` reads AND writes (i.e., TensorArgType::ReadWrite),
|
| 67 |
+
or only reads (i.e., TensorArgType::ReadOnly).
|
| 68 |
+
Default is TensorArgType::ReadWrite for first Tensor, and
|
| 69 |
+
TensorArgType::ReadOnly for the rest.
|
| 70 |
+
|
| 71 |
+
E.g.,
|
| 72 |
+
|
| 73 |
+
to compute a = b^2 for a and b of same dtype, we can call
|
| 74 |
+
|
| 75 |
+
CUDA_tensor_apply2<scalar, scalar>(
|
| 76 |
+
a, b,
|
| 77 |
+
[] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
|
| 78 |
+
);
|
| 79 |
+
|
| 80 |
+
to work on 2 values at the same time, we can call
|
| 81 |
+
|
| 82 |
+
CUDA_tensor_apply2<scalar1, scalar2, 2>(
|
| 83 |
+
a, b,
|
| 84 |
+
[] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
|
| 85 |
+
const scalar2 &b_val1, const scalar2 &b_val2) {
|
| 86 |
+
// call special vectorized op here, or just do elementwise and enjoy unrolling...
|
| 87 |
+
// if n == 1, only process a_val1 and b_val1
|
| 88 |
+
}
|
| 89 |
+
);
|
| 90 |
+
*/
|
| 91 |
+
|
| 92 |
+
namespace at::cuda {
|
| 93 |
+
|
| 94 |
+
// TODO: combine with TensorArg? So far that's been for debugging, and this is functional...
|
| 95 |
+
enum class TensorArgType { ReadWrite, ReadOnly };
|
| 96 |
+
|
| 97 |
+
namespace {
|
| 98 |
+
|
| 99 |
+
// Rearrange dimensions for pointwise operations so that strides are in
|
| 100 |
+
// decreasing order as much as possible, so that kernels have better memory
|
| 101 |
+
// access patterns.
|
| 102 |
+
//
|
| 103 |
+
// For example, consider a binary operation on two "transposed" 2-dim tensors:
|
| 104 |
+
// sizes: 256 512
|
| 105 |
+
// aInfo->strides: 1 256
|
| 106 |
+
// bInfo->strides: 1 256
|
| 107 |
+
//
|
| 108 |
+
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
|
| 109 |
+
// exactly 256 elements apart, resulting in poor performance.
|
| 110 |
+
//
|
| 111 |
+
// This function exchanges dimensions so that memory access is contiguous:
|
| 112 |
+
// sizes: 512 256
|
| 113 |
+
// aInfo->strides: 256 1
|
| 114 |
+
// bInfo->strides: 256 1
|
| 115 |
+
//
|
| 116 |
+
// (Actually, it becomes even better because now collapseDims() can turn each
|
| 117 |
+
// input into one contiguous array.)
|
| 118 |
+
//
|
| 119 |
+
// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
|
| 120 |
+
// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange
|
| 121 |
+
// strides[i] and [j] if
|
| 122 |
+
// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
|
| 123 |
+
// (exchanging them will benefit input #k), and
|
| 124 |
+
// (2) strides[i][k] <= strides[j][k] for all k
|
| 125 |
+
// (exchanging them will not make any input worse).
|
| 126 |
+
// Reorders dimensions in place (see the long comment above for the formal
// swap condition) so strides are as close to decreasing order as possible.
// Mutates the sizes/strides arrays of every non-null TensorInfo; a no-op if
// the infos disagree on rank or shape.
template <typename T1, typename IndexType,
          typename T2 = void, typename T3 = void, typename T4 = void>
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
                          detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
                          detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
                          detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
  int numInfos = 1;
  int dims = aInfo->dims;
  IndexType *sizes[4] = { aInfo->sizes, };
  IndexType *strides[4] = { aInfo->strides, };

  // Gather the optional infos; bail out as soon as any tensor disagrees on
  // dimensionality, since the swap logic below assumes a shared `dims`.
  if (bInfo != nullptr) {
    ++numInfos;
    if (bInfo->dims != dims) return;
    sizes[1] = bInfo->sizes;
    strides[1] = bInfo->strides;
  }

  if (cInfo != nullptr) {
    ++numInfos;
    if (cInfo->dims != dims) return;
    sizes[2] = cInfo->sizes;
    strides[2] = cInfo->strides;
  }

  if (dInfo != nullptr) {
    ++numInfos;
    if (dInfo->dims != dims) return;
    sizes[3] = dInfo->sizes;
    strides[3] = dInfo->strides;
  }

  // Bail out if sizes do not match: we are using "deprecated pointwise
  // behavior" among tensors of different shapes but same number of elements.
  for (int i = 1; i < numInfos; ++i) {
    for (int j = 0; j < dims; ++j) {
      if (sizes[i][j] != sizes[0][j]) return;
    }
  }

  for (int i = 0; i < dims - 1; ++i) {
    // No need to consider dimensions of size 1.
    if (sizes[0][i] == 1) continue;

    for (int j = i + 1; j < dims; ++j) {
      if (sizes[0][j] == 1) continue;

      // Compare the relative sizes of strides between dim #i and dim #j
      // across all tensors.
      bool hasIncreasingStrides = false;
      bool hasDecreasingStrides = false;

      for (int k = 0; k < numInfos; k++) {
        IndexType stride_i = strides[k][i];
        IndexType stride_j = strides[k][j];
        if (stride_i < stride_j) {
          hasIncreasingStrides = true;
        } else if (stride_i > stride_j) {
          hasDecreasingStrides = true;
        }
      }

      // Swap dim #i and dim #j only if it helps at least one tensor and
      // makes none worse.
      if (hasIncreasingStrides && !hasDecreasingStrides) {
        for (int k = 0; k < numInfos; k++) {
          IndexType size = sizes[k][i];
          sizes[k][i] = sizes[k][j];
          sizes[k][j] = size;

          IndexType stride = strides[k][i];
          strides[k][i] = strides[k][j];
          strides[k][j] = stride;
        }
      }
    }
  }
}
|
| 201 |
+
|
| 202 |
+
// The `remaining_steps` argument is used to support Op that operates on
|
| 203 |
+
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
|
| 204 |
+
// 1. Initialize `remaining_steps = step`, where `step` is the template arg of
|
| 205 |
+
// CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
|
| 206 |
+
// number of elements in bound for this call. It will almost always equal to
|
| 207 |
+
// `step` except at boundaries.
|
| 208 |
+
// 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
|
| 209 |
+
// bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
|
| 210 |
+
// 3. At `remaining_steps = 0`,
|
| 211 |
+
// if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
|
| 212 |
+
// if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tensor1_valstep,
|
| 213 |
+
// tensor2_val1, tensor2_val2, ..., tensor2_valstep,
|
| 214 |
+
// ...
|
| 215 |
+
// tensorN_val1, tensorN_val2, ..., tensorN_valstep);`
|
| 216 |
+
//
|
| 217 |
+
// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
|
| 218 |
+
|
| 219 |
+
// Recursive case: each instantiation converts the current `linearIndex` into
// an offset of `a`, appends it to the `aOffsets...` pack, and recurses with
// `remaining_steps - 1` (see the comment block above for the full strategy).
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp1 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                  IndexType linearIndex, Offsets... aOffsets) {
  // Convert `linearIndex` into an offset of `a`. Steps past the valid count
  // (index >= n) get offset 0; `op` receives `n` so it can ignore them.
  const IndexType aOffset = sizeof...(Offsets) < n ?
    detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;

  ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, op, n, linearIndex + 1, aOffsets..., aOffset
  );
}
};

// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename Offset>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
                  int n, IndexType linearIndex, Offset offset) {
  // Single-element call: op(value) with no count argument.
  op(a.data[offset]);
}
};

// Base case for `step > 1`: all offsets have been accumulated, so invoke the
// op once with the valid count `n` and all `step` element references.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          typename... Offsets>
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
                  IndexType linearIndex, Offsets... offsets) {
  op(n, a.data[offsets]...);
}
};
|
| 266 |
+
|
| 267 |
+
// Single-tensor pointwise kernel: a grid-stride loop where each iteration
// handles `step` consecutive linear indices via ApplyOp1. The min() clamps
// the valid-element count at the tail of the tensor.
template <typename Op,
          typename scalar,
          typename IndexType,
          int ADims,
          int step>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                      IndexType totalElements, const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
      a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
  }
}
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
// Two-tensor analogue of ApplyOp1: each recursion level converts the current
// `linearIndex` into one offset per tensor and appends both to their packs.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          int remaining_steps,
          typename... Offsets>
struct ApplyOp2 {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int64_t n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  // Convert `linearIndex` into an offset of `a`. Steps past the valid count
  // get offset 0; `op` receives `n` so it can ignore them.
  const IndexType aOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
    detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;

  // Convert `linearIndex` into an offset of `b`
  const IndexType bOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
    detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;

  ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
    a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
  );
}
};

// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
// We don't need to pass in how many elements need to processed in this case.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename Offset>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int /*n*/, IndexType /*linearIndex*/,
                  Offset aOffset, Offset bOffset) {
  // Single-element call: op(a_value, b_value) with no count argument.
  op(a.data[aOffset], b.data[bOffset]);
}
};

// Base case for `step > 1`: all offsets accumulated; invoke the op once with
// the valid count `n` and all `step` references from each tensor.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims,
          int BDims,
          typename... Offsets>
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
__device__ __forceinline__
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
                  detail::TensorInfo<scalar2, IndexType> &b,
                  const Op &op, int n, IndexType linearIndex,
                  Offsets... aOffsets, Offsets... bOffsets) {
  op(n, a.data[aOffsets]..., b.data[bOffsets]...);
}
};
|
| 349 |
+
|
| 350 |
+
// Two-tensor pointwise kernel: a grid-stride loop where each iteration
// handles `step` consecutive linear indices via ApplyOp2. The min() clamps
// the valid-element count at the tail of the tensors.
template <typename Op,
          typename scalar1,
          typename scalar2,
          typename IndexType,
          int ADims, int BDims,
          int step,
          int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
          int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
                      detail::TensorInfo<scalar2, IndexType> b,
                      IndexType totalElements,
                      const Op op) {
  for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * step) {
    ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
      a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
      linearIndex);
  }
}
|
| 374 |
+
|
| 375 |
+
} // anonymous namespace
|
| 376 |
+
|
| 377 |
+
template <typename scalar1, typename scalar2, int step, typename Op,
|
| 378 |
+
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
|
| 379 |
+
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
|
| 380 |
+
inline bool CUDA_tensor_apply2(at::TensorBase a,
|
| 381 |
+
at::TensorBase b,
|
| 382 |
+
const Op op,
|
| 383 |
+
TensorArgType aType = TensorArgType::ReadWrite,
|
| 384 |
+
TensorArgType bType = TensorArgType::ReadOnly) {
|
| 385 |
+
TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
|
| 386 |
+
"CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
|
| 387 |
+
"tensors with type ", a.device().type(), " and ", b.device().type());
|
| 388 |
+
int64_t totalElements = a.numel();
|
| 389 |
+
|
| 390 |
+
if (totalElements != b.numel()) {
|
| 391 |
+
return false;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
if (a.dim() > MAX_TENSORINFO_DIMS ||
|
| 395 |
+
b.dim() > MAX_TENSORINFO_DIMS) {
|
| 396 |
+
return false;
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
if (a.numel() == 0) {
|
| 400 |
+
// Empty tensor; do nothing
|
| 401 |
+
return true;
|
| 402 |
+
}
|
| 403 |
+
const dim3 block = getApplyBlock(max_threads_per_block);
|
| 404 |
+
|
| 405 |
+
dim3 grid;
|
| 406 |
+
auto curDevice = current_device();
|
| 407 |
+
if (curDevice == -1) return false;
|
| 408 |
+
if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
|
| 409 |
+
return false;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
/*
|
| 413 |
+
Expands readable/writable tensors whose indices may be "overlapped."
|
| 414 |
+
This ensures that each element of the tensor is operated on once and only
|
| 415 |
+
once.
|
| 416 |
+
*/
|
| 417 |
+
TensorBase oldA;
|
| 418 |
+
TensorBase oldB;
|
| 419 |
+
|
| 420 |
+
if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
|
| 421 |
+
// Must perform in contiguous space
|
| 422 |
+
oldA = std::exchange(a, a.contiguous());
|
| 423 |
+
}
|
| 424 |
+
if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
|
| 425 |
+
// Must perform in contiguous space
|
| 426 |
+
oldB = std::exchange(b, b.contiguous());
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
// It is possible that the tensor dimensions are able to be collapsed,
|
| 430 |
+
// and thus we can reduce the actual code complexity of the copy by
|
| 431 |
+
// exploiting this knowledge statically, since the div/mod is the
|
| 432 |
+
// most expensive part of the operation, more so than memory accesses.
|
| 433 |
+
// For instance, when copying a non-contiguous to a contiguous tensor
|
| 434 |
+
// (or vice versa), the contiguous tensor can be collapsed to one
|
| 435 |
+
// dimension, and the loop to translate the linear index to the array
|
| 436 |
+
// index can be similarly collapsed. That is what this unrolling is for.
|
| 437 |
+
|
| 438 |
+
#define HANDLE_CASE(TYPE, A, B) \
|
| 439 |
+
kernelPointwiseApply2<Op, \
|
| 440 |
+
scalar1, \
|
| 441 |
+
scalar2, \
|
| 442 |
+
TYPE, A, B, step, \
|
| 443 |
+
max_threads_per_block, \
|
| 444 |
+
min_blocks_per_sm> \
|
| 445 |
+
<<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>( \
|
| 446 |
+
aInfo, bInfo, static_cast<TYPE>(totalElements), op); \
|
| 447 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 448 |
+
|
| 449 |
+
#define HANDLE_B_CASE(TYPE, A, B) { \
|
| 450 |
+
switch (B) { \
|
| 451 |
+
case 1: \
|
| 452 |
+
HANDLE_CASE(TYPE, A, 1); \
|
| 453 |
+
break; \
|
| 454 |
+
case 2: \
|
| 455 |
+
HANDLE_CASE(TYPE, A, 2); \
|
| 456 |
+
break; \
|
| 457 |
+
default: \
|
| 458 |
+
HANDLE_CASE(TYPE, A, -1); \
|
| 459 |
+
break; \
|
| 460 |
+
} \
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
#define HANDLE_A_CASE(TYPE, A, B) { \
|
| 464 |
+
switch (A) { \
|
| 465 |
+
case 1: \
|
| 466 |
+
HANDLE_B_CASE(TYPE, 1, B); \
|
| 467 |
+
break; \
|
| 468 |
+
case 2: \
|
| 469 |
+
HANDLE_B_CASE(TYPE, 2, B); \
|
| 470 |
+
break; \
|
| 471 |
+
default: \
|
| 472 |
+
HANDLE_B_CASE(TYPE, -1, B); \
|
| 473 |
+
break; \
|
| 474 |
+
} \
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
if (detail::canUse32BitIndexMath(a) &&
|
| 478 |
+
detail::canUse32BitIndexMath(b)) {
|
| 479 |
+
detail::TensorInfo<scalar1, unsigned int> aInfo =
|
| 480 |
+
detail::getTensorInfo<scalar1, unsigned int>(a);
|
| 481 |
+
|
| 482 |
+
detail::TensorInfo<scalar2, unsigned int> bInfo =
|
| 483 |
+
detail::getTensorInfo<scalar2, unsigned int>(b);
|
| 484 |
+
rearrangeDims(&aInfo, &bInfo);
|
| 485 |
+
aInfo.collapseDims();
|
| 486 |
+
bInfo.collapseDims();
|
| 487 |
+
|
| 488 |
+
HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
|
| 489 |
+
} else {
|
| 490 |
+
detail::TensorInfo<scalar1, uint64_t> aInfo =
|
| 491 |
+
detail::getTensorInfo<scalar1, uint64_t>(a);
|
| 492 |
+
|
| 493 |
+
detail::TensorInfo<scalar2, uint64_t> bInfo =
|
| 494 |
+
detail::getTensorInfo<scalar2, uint64_t>(b);
|
| 495 |
+
rearrangeDims(&aInfo, &bInfo);
|
| 496 |
+
aInfo.collapseDims();
|
| 497 |
+
bInfo.collapseDims();
|
| 498 |
+
|
| 499 |
+
/*
|
| 500 |
+
Only instantiates the all 1D special case and the fallback all nD case for
|
| 501 |
+
large (64-bit indexed) tensors to reduce compilation time.
|
| 502 |
+
*/
|
| 503 |
+
if (aInfo.dims == 1 && bInfo.dims == 1) {
|
| 504 |
+
HANDLE_CASE(uint64_t, 1, 1);
|
| 505 |
+
} else {
|
| 506 |
+
HANDLE_CASE(uint64_t, -1, -1);
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
#undef HANDLE_CASE
|
| 510 |
+
#undef HANDLE_B_CASE
|
| 511 |
+
#undef HANDLE_A_CASE
|
| 512 |
+
|
| 513 |
+
if (oldA.defined()) {
|
| 514 |
+
at::native::copy_ignoring_overlaps(oldA, a);
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
if (oldB.defined()) {
|
| 518 |
+
at::native::copy_ignoring_overlaps(oldB, b);
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
return true;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
/* Provides default step = 1 to CUDA_tensor_apply2. */
|
| 525 |
+
template <typename scalar1, typename scalar2, typename Op,
|
| 526 |
+
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
|
| 527 |
+
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
|
| 528 |
+
inline bool CUDA_tensor_apply2(const at::TensorBase &a,
|
| 529 |
+
const at::TensorBase &b,
|
| 530 |
+
const Op op,
|
| 531 |
+
TensorArgType aType = TensorArgType::ReadWrite,
|
| 532 |
+
TensorArgType bType = TensorArgType::ReadOnly) {
|
| 533 |
+
return CUDA_tensor_apply2<scalar1, scalar2, 1, Op,
|
| 534 |
+
max_threads_per_block, min_blocks_per_sm>(a, b, op, aType, bType);
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGeneratorImpl.h
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Generator.h>
|
| 4 |
+
#include <ATen/cuda/PhiloxCudaState.h>
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <limits>
|
| 7 |
+
#include <atomic>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
/**
|
| 11 |
+
* Note [CUDA Graph-safe RNG states]
|
| 12 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 13 |
+
*
|
| 14 |
+
* Strategy:
|
| 15 |
+
* ~~~~~~~~~
|
| 16 |
+
* (It helps to look at
|
| 17 |
+
* cuda/detail/PhiloxCudaStateRaw.cuh and
|
| 18 |
+
* cuda/detail/UnpackRaw.cuh
|
| 19 |
+
* while you read this.)
|
| 20 |
+
*
|
| 21 |
+
* A CUDA graph containing multiple RNG ops behaves like a
|
| 22 |
+
* single giant kernel from the perspective of ops external
|
| 23 |
+
* to the graph. During graph capture, logic in CUDAGeneratorImpl
|
| 24 |
+
* records the total of all offset increments that occur in the
|
| 25 |
+
* graphed region, and records the final total as the offset for
|
| 26 |
+
* the entire graph.
|
| 27 |
+
*
|
| 28 |
+
* When the graph reruns, the logic that reruns it
|
| 29 |
+
* increments this device's CUDA generator's offset
|
| 30 |
+
* by that total.
|
| 31 |
+
*
|
| 32 |
+
* Meanwhile, within the graph, at capture time, instead of
|
| 33 |
+
* populating PhiloxCudaStates with the uint64_t offset pulled
|
| 34 |
+
* directly from the global state, PhiloxCudaState uses a pointer
|
| 35 |
+
* to a one-element stream-local int64_t device tensor
|
| 36 |
+
* holding an initial offset value, and a uint64_t holding an
|
| 37 |
+
* intra-graph offset. (The intra-graph offset starts from zero
|
| 38 |
+
* when capture begins.) In each consumer kernel,
|
| 39 |
+
* at::cuda::philox::unpack computes the offset to use for this kernel
|
| 40 |
+
* as intra-graph offset + *initial offset.
|
| 41 |
+
*
|
| 42 |
+
* When the graph reruns, the logic that reruns it first
|
| 43 |
+
* fill_s the initial offset tensor with this device's
|
| 44 |
+
* CUDA generator's current offset.
|
| 45 |
+
*
|
| 46 |
+
* The control flow above ensures graphed execution is bitwise
|
| 47 |
+
* identical to eager execution as long as RNG ops are enqueued
|
| 48 |
+
* from a single thread, even if RNG ops and graphs containing
|
| 49 |
+
* RNG ops are enqueued and run simultaneously on multiple streams.
|
| 50 |
+
*
|
| 51 |
+
* Usage:
|
| 52 |
+
* ~~~~~~
|
| 53 |
+
* PhiloxCudaState in this file, and unpack() in
|
| 54 |
+
* cuda/CUDAGraphsUtils.cuh allow non-divergent use of
|
| 55 |
+
* CUDAGeneratorImpl whether graph capture is underway or not.
|
| 56 |
+
*
|
| 57 |
+
* Each PhiloxCudaState instance should be used for one and only one
|
| 58 |
+
* consumer kernel.
|
| 59 |
+
*
|
| 60 |
+
* Example (see e.g. native/cuda/Dropout.cu):
|
| 61 |
+
*
|
| 62 |
+
* #include <ATen/cuda/CUDAGeneratorImpl.h>
|
| 63 |
+
* #include <ATen/cuda/CUDAGraphsUtils.cuh>
|
| 64 |
+
*
|
| 65 |
+
* __global__ void kernel(..., PhiloxCudaState philox_args) {
|
| 66 |
+
* auto seeds = at::cuda::philox::unpack(philox_args);
|
| 67 |
+
* IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 68 |
+
* curandStatePhilox4_32_10_t state;
|
| 69 |
+
* curand_init(std::get<0>(seeds), // seed
|
| 70 |
+
* idx, // per-thread subsequence
|
| 71 |
+
* std::get<1>(seeds), // offset in subsequence
|
| 72 |
+
* &state);
|
| 73 |
+
* ...
|
| 74 |
+
* }
|
| 75 |
+
*
|
| 76 |
+
* host_caller(...) {
|
| 77 |
+
* PhiloxCudaState rng_engine_inputs;
|
| 78 |
+
* {
|
| 79 |
+
* // See Note [Acquire lock when using random generators]
|
| 80 |
+
* std::lock_guard<std::mutex> lock(gen->mutex_);
|
| 81 |
+
*
|
| 82 |
+
* // gen could be HostState or DevState here! No divergent code needed!
|
| 83 |
+
* rng_engine_inputs = gen->philox_cuda_state(offset_increment);
|
| 84 |
+
* }
|
| 85 |
+
* kernel<<<...>>>(..., rng_engine_inputs);
|
| 86 |
+
* }
|
| 87 |
+
*
|
| 88 |
+
*/
|
| 89 |
+
|
| 90 |
+
struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
|
| 91 |
+
// Constructors
|
| 92 |
+
CUDAGeneratorImpl(DeviceIndex device_index = -1);
|
| 93 |
+
~CUDAGeneratorImpl() override = default;
|
| 94 |
+
|
| 95 |
+
// CUDAGeneratorImpl methods
|
| 96 |
+
std::shared_ptr<CUDAGeneratorImpl> clone() const;
|
| 97 |
+
void set_current_seed(uint64_t seed) override;
|
| 98 |
+
void set_offset(uint64_t offset) override;
|
| 99 |
+
uint64_t get_offset() const override;
|
| 100 |
+
uint64_t current_seed() const override;
|
| 101 |
+
uint64_t seed() override;
|
| 102 |
+
void set_state(const c10::TensorImpl& new_state) override;
|
| 103 |
+
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
| 104 |
+
void set_philox_offset_per_thread(uint64_t offset);
|
| 105 |
+
uint64_t philox_offset_per_thread() const;
|
| 106 |
+
void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
|
| 107 |
+
uint64_t capture_epilogue();
|
| 108 |
+
PhiloxCudaState philox_cuda_state(uint64_t increment);
|
| 109 |
+
|
| 110 |
+
bool reset_rnn_state() {
|
| 111 |
+
return !no_reset_rnn_state_.test_and_set();
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// Temporarily accommodates call sites that use philox_engine_inputs.
|
| 115 |
+
// Allows incremental refactor of call sites to use philox_cuda_state.
|
| 116 |
+
std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
|
| 117 |
+
|
| 118 |
+
static c10::DeviceType device_type();
|
| 119 |
+
|
| 120 |
+
private:
|
| 121 |
+
CUDAGeneratorImpl* clone_impl() const override;
|
| 122 |
+
uint64_t seed_ = default_rng_seed_val;
|
| 123 |
+
uint64_t philox_offset_per_thread_ = 0;
|
| 124 |
+
int64_t* seed_extragraph_{};
|
| 125 |
+
int64_t* offset_extragraph_{};
|
| 126 |
+
uint32_t offset_intragraph_ = 0;
|
| 127 |
+
bool graph_expects_this_gen_ = false;
|
| 128 |
+
std::atomic_flag no_reset_rnn_state_;
|
| 129 |
+
};
|
| 130 |
+
|
| 131 |
+
namespace cuda::detail {
|
| 132 |
+
|
| 133 |
+
TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
|
| 134 |
+
DeviceIndex device_index = -1);
|
| 135 |
+
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);
|
| 136 |
+
|
| 137 |
+
} // namespace cuda::detail
|
| 138 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraph.h
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/Device.h>
|
| 5 |
+
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
| 6 |
+
#include <c10/cuda/CUDAStream.h>
|
| 7 |
+
|
| 8 |
+
#include <mutex>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
struct CUDAGeneratorImpl;
|
| 13 |
+
|
| 14 |
+
namespace cuda {
|
| 15 |
+
|
| 16 |
+
// Standalone way to get a unique mempool id usable as a pool=... argument
|
| 17 |
+
// to CUDAGraph::capture_begin
|
| 18 |
+
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
|
| 19 |
+
|
| 20 |
+
struct TORCH_CUDA_CPP_API CUDAGraph {
|
| 21 |
+
CUDAGraph();
|
| 22 |
+
~CUDAGraph();
|
| 23 |
+
|
| 24 |
+
static void inc_pending_event_queries();
|
| 25 |
+
static void dec_pending_event_queries();
|
| 26 |
+
static int num_pending_event_queries();
|
| 27 |
+
void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
|
| 28 |
+
void capture_end();
|
| 29 |
+
void replay();
|
| 30 |
+
void reset();
|
| 31 |
+
MempoolId_t pool();
|
| 32 |
+
void enable_debug_mode();
|
| 33 |
+
void debug_dump(const std::string& debug_path);
|
| 34 |
+
|
| 35 |
+
protected:
|
| 36 |
+
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
|
| 37 |
+
cudaGraph_t graph_ = NULL;
|
| 38 |
+
cudaGraphExec_t graph_exec_ = NULL;
|
| 39 |
+
#endif
|
| 40 |
+
|
| 41 |
+
static std::atomic<int> pending_event_queries;
|
| 42 |
+
|
| 43 |
+
// internal states so reset() can do its best cleaning up
|
| 44 |
+
// Set to true in capture_end if cudaStreamEndCapture succeeded
|
| 45 |
+
// Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
|
| 46 |
+
// to create graph_exec_, then graph_ is deleted
|
| 47 |
+
bool has_graph_ = false;
|
| 48 |
+
// Set to true in capture_end if cudaGraphInstantiate succeeded
|
| 49 |
+
bool has_graph_exec_ = false;
|
| 50 |
+
|
| 51 |
+
// uuid of this instance's current capture, used to
|
| 52 |
+
// specify the pool.
|
| 53 |
+
CaptureId_t id_;
|
| 54 |
+
|
| 55 |
+
// the ID assigned by cuda during graph capture,
|
| 56 |
+
// used to identify when a stream is participating in capture
|
| 57 |
+
CaptureId_t capture_id_ = -1;
|
| 58 |
+
|
| 59 |
+
// uuid used to request a particular private mempool from CUDACachingAllocator.
|
| 60 |
+
// By default, this will be set to {id_, 0}.
|
| 61 |
+
//
|
| 62 |
+
// If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
|
| 63 |
+
// will be set to the other graph's mempool_id_, and therefore share a mempool with the
|
| 64 |
+
// other graph.
|
| 65 |
+
//
|
| 66 |
+
// If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
|
| 67 |
+
// it will share a mempool with any other captures that used "pool=handle".
|
| 68 |
+
//
|
| 69 |
+
// Sharing a mempool across graphs saves memory, and it's safe if you
|
| 70 |
+
// know you'll replay those graphs in the same order you captured them.
|
| 71 |
+
MempoolId_t mempool_id_;
|
| 72 |
+
|
| 73 |
+
// Stream on which capture began
|
| 74 |
+
at::cuda::CUDAStream capture_stream_;
|
| 75 |
+
|
| 76 |
+
// Default generator on device where capture began
|
| 77 |
+
at::CUDAGeneratorImpl* capture_gen_;
|
| 78 |
+
|
| 79 |
+
// Device where capture occurred. Right now, for simplicity, we require all ops
|
| 80 |
+
// in a capture to run on the same device, but this is a limitation of CUDAGraph,
|
| 81 |
+
// not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
|
| 82 |
+
// captures if needed.
|
| 83 |
+
int capture_dev_;
|
| 84 |
+
|
| 85 |
+
// RNG state trackers
|
| 86 |
+
at::Tensor seed_extragraph_;
|
| 87 |
+
at::Tensor offset_extragraph_;
|
| 88 |
+
uint64_t wholegraph_increment_;
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
} // namespace cuda
|
| 92 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAGraphsUtils.cuh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
| 4 |
+
#include <ATen/cuda/CUDAEvent.h>
|
| 5 |
+
#include <ATen/cuda/PhiloxUtils.cuh>
|
| 6 |
+
#include <ATen/cuda/detail/CUDAHooks.h>
|
| 7 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 8 |
+
#include <c10/core/StreamGuard.h>
|
| 9 |
+
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
| 10 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 11 |
+
|
| 12 |
+
// c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
|
| 13 |
+
// This file adds utils used by aten only.
|
| 14 |
+
|
| 15 |
+
namespace at::cuda {
|
| 16 |
+
|
| 17 |
+
using CaptureId_t = c10::cuda::CaptureId_t;
|
| 18 |
+
using CaptureStatus = c10::cuda::CaptureStatus;
|
| 19 |
+
|
| 20 |
+
// Use this version where you don't want to create a CUDA context if none exists.
|
| 21 |
+
inline CaptureStatus currentStreamCaptureStatus() {
|
| 22 |
+
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
|
| 23 |
+
// don't create a context if we don't have to
|
| 24 |
+
if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
|
| 25 |
+
return c10::cuda::currentStreamCaptureStatusMayInitCtx();
|
| 26 |
+
} else {
|
| 27 |
+
return CaptureStatus::None;
|
| 28 |
+
}
|
| 29 |
+
#else
|
| 30 |
+
return CaptureStatus::None;
|
| 31 |
+
#endif
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
inline void assertNotCapturing(std::string attempt) {
|
| 35 |
+
auto status = currentStreamCaptureStatus();
|
| 36 |
+
TORCH_CHECK(status == CaptureStatus::None,
|
| 37 |
+
attempt,
|
| 38 |
+
" during CUDA graph capture. If you need this call to be captured, "
|
| 39 |
+
"please file an issue. "
|
| 40 |
+
"Current cudaStreamCaptureStatus: ",
|
| 41 |
+
status);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
inline void errorIfCapturingCudnnBenchmark(std::string version_specific) {
|
| 45 |
+
auto status = currentStreamCaptureStatus();
|
| 46 |
+
TORCH_CHECK(status == CaptureStatus::None,
|
| 47 |
+
"Current cudaStreamCaptureStatus: ",
|
| 48 |
+
status,
|
| 49 |
+
"\nCapturing ",
|
| 50 |
+
version_specific,
|
| 51 |
+
"is prohibited. Possible causes of this error:\n"
|
| 52 |
+
"1. No warmup iterations occurred before capture.\n"
|
| 53 |
+
"2. The convolutions you're trying to capture use dynamic shapes, "
|
| 54 |
+
"in which case capturing them is generally prohibited.");
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparseBlas.h
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
/*
|
| 4 |
+
Provides a subset of cuSPARSE functions as templates:
|
| 5 |
+
|
| 6 |
+
csrgeam2<scalar_t>(...)
|
| 7 |
+
|
| 8 |
+
where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
|
| 9 |
+
The functions are available in at::cuda::sparse namespace.
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 13 |
+
#include <ATen/cuda/CUDASparse.h>
|
| 14 |
+
|
| 15 |
+
namespace at::cuda::sparse {
|
| 16 |
+
|
| 17 |
+
#define CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t) \
|
| 18 |
+
cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
|
| 19 |
+
const cusparseMatDescr_t descrA, int nnzA, \
|
| 20 |
+
const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
|
| 21 |
+
const int *csrSortedColIndA, const scalar_t *beta, \
|
| 22 |
+
const cusparseMatDescr_t descrB, int nnzB, \
|
| 23 |
+
const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
|
| 24 |
+
const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
|
| 25 |
+
const scalar_t *csrSortedValC, const int *csrSortedRowPtrC, \
|
| 26 |
+
const int *csrSortedColIndC, size_t *pBufferSizeInBytes
|
| 27 |
+
|
| 28 |
+
template <typename scalar_t>
|
| 29 |
+
inline void csrgeam2_bufferSizeExt(
|
| 30 |
+
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(scalar_t)) {
|
| 31 |
+
TORCH_INTERNAL_ASSERT(
|
| 32 |
+
false,
|
| 33 |
+
"at::cuda::sparse::csrgeam2_bufferSizeExt: not implemented for ",
|
| 34 |
+
typeid(scalar_t).name());
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template <>
|
| 38 |
+
void csrgeam2_bufferSizeExt<float>(
|
| 39 |
+
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float));
|
| 40 |
+
template <>
|
| 41 |
+
void csrgeam2_bufferSizeExt<double>(
|
| 42 |
+
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double));
|
| 43 |
+
template <>
|
| 44 |
+
void csrgeam2_bufferSizeExt<c10::complex<float>>(
|
| 45 |
+
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>));
|
| 46 |
+
template <>
|
| 47 |
+
void csrgeam2_bufferSizeExt<c10::complex<double>>(
|
| 48 |
+
CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>));
|
| 49 |
+
|
| 50 |
+
#define CUSPARSE_CSRGEAM2_NNZ_ARGTYPES() \
|
| 51 |
+
cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, \
|
| 52 |
+
int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, \
|
| 53 |
+
const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, \
|
| 54 |
+
const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
|
| 55 |
+
int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace
|
| 56 |
+
|
| 57 |
+
template <typename scalar_t>
|
| 58 |
+
inline void csrgeam2Nnz(CUSPARSE_CSRGEAM2_NNZ_ARGTYPES()) {
|
| 59 |
+
TORCH_CUDASPARSE_CHECK(cusparseXcsrgeam2Nnz(
|
| 60 |
+
handle,
|
| 61 |
+
m,
|
| 62 |
+
n,
|
| 63 |
+
descrA,
|
| 64 |
+
nnzA,
|
| 65 |
+
csrSortedRowPtrA,
|
| 66 |
+
csrSortedColIndA,
|
| 67 |
+
descrB,
|
| 68 |
+
nnzB,
|
| 69 |
+
csrSortedRowPtrB,
|
| 70 |
+
csrSortedColIndB,
|
| 71 |
+
descrC,
|
| 72 |
+
csrSortedRowPtrC,
|
| 73 |
+
nnzTotalDevHostPtr,
|
| 74 |
+
workspace));
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
#define CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t) \
|
| 78 |
+
cusparseHandle_t handle, int m, int n, const scalar_t *alpha, \
|
| 79 |
+
const cusparseMatDescr_t descrA, int nnzA, \
|
| 80 |
+
const scalar_t *csrSortedValA, const int *csrSortedRowPtrA, \
|
| 81 |
+
const int *csrSortedColIndA, const scalar_t *beta, \
|
| 82 |
+
const cusparseMatDescr_t descrB, int nnzB, \
|
| 83 |
+
const scalar_t *csrSortedValB, const int *csrSortedRowPtrB, \
|
| 84 |
+
const int *csrSortedColIndB, const cusparseMatDescr_t descrC, \
|
| 85 |
+
scalar_t *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, \
|
| 86 |
+
void *pBuffer
|
| 87 |
+
|
| 88 |
+
template <typename scalar_t>
|
| 89 |
+
inline void csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES(scalar_t)) {
|
| 90 |
+
TORCH_INTERNAL_ASSERT(
|
| 91 |
+
false,
|
| 92 |
+
"at::cuda::sparse::csrgeam2: not implemented for ",
|
| 93 |
+
typeid(scalar_t).name());
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
template <>
|
| 97 |
+
void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float));
|
| 98 |
+
template <>
|
| 99 |
+
void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double));
|
| 100 |
+
template <>
|
| 101 |
+
void csrgeam2<c10::complex<float>>(
|
| 102 |
+
CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>));
|
| 103 |
+
template <>
|
| 104 |
+
void csrgeam2<c10::complex<double>>(
|
| 105 |
+
CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>));
|
| 106 |
+
|
| 107 |
+
#define CUSPARSE_BSRMM_ARGTYPES(scalar_t) \
|
| 108 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 109 |
+
cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, \
|
| 110 |
+
int kb, int nnzb, const scalar_t *alpha, \
|
| 111 |
+
const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
|
| 112 |
+
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
|
| 113 |
+
const scalar_t *B, int ldb, const scalar_t *beta, scalar_t *C, int ldc
|
| 114 |
+
|
| 115 |
+
template <typename scalar_t>
|
| 116 |
+
inline void bsrmm(CUSPARSE_BSRMM_ARGTYPES(scalar_t)) {
|
| 117 |
+
TORCH_INTERNAL_ASSERT(
|
| 118 |
+
false,
|
| 119 |
+
"at::cuda::sparse::bsrmm: not implemented for ",
|
| 120 |
+
typeid(scalar_t).name());
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
template <>
|
| 124 |
+
void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float));
|
| 125 |
+
template <>
|
| 126 |
+
void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double));
|
| 127 |
+
template <>
|
| 128 |
+
void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
|
| 129 |
+
template <>
|
| 130 |
+
void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
|
| 131 |
+
|
| 132 |
+
#define CUSPARSE_BSRMV_ARGTYPES(scalar_t) \
|
| 133 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 134 |
+
cusparseOperation_t transA, int mb, int nb, int nnzb, \
|
| 135 |
+
const scalar_t *alpha, const cusparseMatDescr_t descrA, \
|
| 136 |
+
const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
|
| 137 |
+
int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
|
| 138 |
+
|
| 139 |
+
template <typename scalar_t>
|
| 140 |
+
inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
|
| 141 |
+
TORCH_INTERNAL_ASSERT(
|
| 142 |
+
false,
|
| 143 |
+
"at::cuda::sparse::bsrmv: not implemented for ",
|
| 144 |
+
typeid(scalar_t).name());
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
template <>
|
| 148 |
+
void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
|
| 149 |
+
template <>
|
| 150 |
+
void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
|
| 151 |
+
template <>
|
| 152 |
+
void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
|
| 153 |
+
template <>
|
| 154 |
+
void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
|
| 155 |
+
|
| 156 |
+
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
|
| 157 |
+
|
| 158 |
+
#define CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t) \
|
| 159 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 160 |
+
cusparseOperation_t transA, int mb, int nnzb, \
|
| 161 |
+
const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
|
| 162 |
+
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
|
| 163 |
+
bsrsv2Info_t info, int *pBufferSizeInBytes
|
| 164 |
+
|
| 165 |
+
template <typename scalar_t>
|
| 166 |
+
inline void bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(scalar_t)) {
|
| 167 |
+
TORCH_INTERNAL_ASSERT(
|
| 168 |
+
false,
|
| 169 |
+
"at::cuda::sparse::bsrsv2_bufferSize: not implemented for ",
|
| 170 |
+
typeid(scalar_t).name());
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
template <>
|
| 174 |
+
void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float));
|
| 175 |
+
template <>
|
| 176 |
+
void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double));
|
| 177 |
+
template <>
|
| 178 |
+
void bsrsv2_bufferSize<c10::complex<float>>(
|
| 179 |
+
CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>));
|
| 180 |
+
template <>
|
| 181 |
+
void bsrsv2_bufferSize<c10::complex<double>>(
|
| 182 |
+
CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>));
|
| 183 |
+
|
| 184 |
+
#define CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t) \
|
| 185 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 186 |
+
cusparseOperation_t transA, int mb, int nnzb, \
|
| 187 |
+
const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
|
| 188 |
+
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
|
| 189 |
+
bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
|
| 190 |
+
|
| 191 |
+
template <typename scalar_t>
|
| 192 |
+
inline void bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(scalar_t)) {
|
| 193 |
+
TORCH_INTERNAL_ASSERT(
|
| 194 |
+
false,
|
| 195 |
+
"at::cuda::sparse::bsrsv2_analysis: not implemented for ",
|
| 196 |
+
typeid(scalar_t).name());
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
template <>
|
| 200 |
+
void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float));
|
| 201 |
+
template <>
|
| 202 |
+
void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double));
|
| 203 |
+
template <>
|
| 204 |
+
void bsrsv2_analysis<c10::complex<float>>(
|
| 205 |
+
CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>));
|
| 206 |
+
template <>
|
| 207 |
+
void bsrsv2_analysis<c10::complex<double>>(
|
| 208 |
+
CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>));
|
| 209 |
+
|
| 210 |
+
#define CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t) \
|
| 211 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 212 |
+
cusparseOperation_t transA, int mb, int nnzb, const scalar_t *alpha, \
|
| 213 |
+
const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
|
| 214 |
+
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
|
| 215 |
+
bsrsv2Info_t info, const scalar_t *x, scalar_t *y, \
|
| 216 |
+
cusparseSolvePolicy_t policy, void *pBuffer
|
| 217 |
+
|
| 218 |
+
template <typename scalar_t>
|
| 219 |
+
inline void bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(scalar_t)) {
|
| 220 |
+
TORCH_INTERNAL_ASSERT(
|
| 221 |
+
false,
|
| 222 |
+
"at::cuda::sparse::bsrsv2_solve: not implemented for ",
|
| 223 |
+
typeid(scalar_t).name());
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
template <>
|
| 227 |
+
void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float));
|
| 228 |
+
template <>
|
| 229 |
+
void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double));
|
| 230 |
+
template <>
|
| 231 |
+
void bsrsv2_solve<c10::complex<float>>(
|
| 232 |
+
CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>));
|
| 233 |
+
template <>
|
| 234 |
+
void bsrsv2_solve<c10::complex<double>>(
|
| 235 |
+
CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>));
|
| 236 |
+
|
| 237 |
+
#define CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t) \
|
| 238 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 239 |
+
cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
|
| 240 |
+
int nnzb, const cusparseMatDescr_t descrA, scalar_t *bsrValA, \
|
| 241 |
+
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
|
| 242 |
+
bsrsm2Info_t info, int *pBufferSizeInBytes
|
| 243 |
+
|
| 244 |
+
template <typename scalar_t>
|
| 245 |
+
inline void bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(scalar_t)) {
|
| 246 |
+
TORCH_INTERNAL_ASSERT(
|
| 247 |
+
false,
|
| 248 |
+
"at::cuda::sparse::bsrsm2_bufferSize: not implemented for ",
|
| 249 |
+
typeid(scalar_t).name());
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
template <>
|
| 253 |
+
void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float));
|
| 254 |
+
template <>
|
| 255 |
+
void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double));
|
| 256 |
+
template <>
|
| 257 |
+
void bsrsm2_bufferSize<c10::complex<float>>(
|
| 258 |
+
CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>));
|
| 259 |
+
template <>
|
| 260 |
+
void bsrsm2_bufferSize<c10::complex<double>>(
|
| 261 |
+
CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>));
|
| 262 |
+
|
| 263 |
+
#define CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t) \
|
| 264 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 265 |
+
cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
|
| 266 |
+
int nnzb, const cusparseMatDescr_t descrA, const scalar_t *bsrValA, \
|
| 267 |
+
const int *bsrRowPtrA, const int *bsrColIndA, int blockDim, \
|
| 268 |
+
bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer
|
| 269 |
+
|
| 270 |
+
template <typename scalar_t>
|
| 271 |
+
inline void bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(scalar_t)) {
|
| 272 |
+
TORCH_INTERNAL_ASSERT(
|
| 273 |
+
false,
|
| 274 |
+
"at::cuda::sparse::bsrsm2_analysis: not implemented for ",
|
| 275 |
+
typeid(scalar_t).name());
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
template <>
|
| 279 |
+
void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float));
|
| 280 |
+
template <>
|
| 281 |
+
void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double));
|
| 282 |
+
template <>
|
| 283 |
+
void bsrsm2_analysis<c10::complex<float>>(
|
| 284 |
+
CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>));
|
| 285 |
+
template <>
|
| 286 |
+
void bsrsm2_analysis<c10::complex<double>>(
|
| 287 |
+
CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>));
|
| 288 |
+
|
| 289 |
+
#define CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t) \
|
| 290 |
+
cusparseHandle_t handle, cusparseDirection_t dirA, \
|
| 291 |
+
cusparseOperation_t transA, cusparseOperation_t transX, int mb, int n, \
|
| 292 |
+
int nnzb, const scalar_t *alpha, const cusparseMatDescr_t descrA, \
|
| 293 |
+
const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
|
| 294 |
+
int blockDim, bsrsm2Info_t info, const scalar_t *B, int ldb, \
|
| 295 |
+
scalar_t *X, int ldx, cusparseSolvePolicy_t policy, void *pBuffer
|
| 296 |
+
|
| 297 |
+
template <typename scalar_t>
|
| 298 |
+
inline void bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(scalar_t)) {
|
| 299 |
+
TORCH_INTERNAL_ASSERT(
|
| 300 |
+
false,
|
| 301 |
+
"at::cuda::sparse::bsrsm2_solve: not implemented for ",
|
| 302 |
+
typeid(scalar_t).name());
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
template <>
|
| 306 |
+
void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float));
|
| 307 |
+
template <>
|
| 308 |
+
void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double));
|
| 309 |
+
template <>
|
| 310 |
+
void bsrsm2_solve<c10::complex<float>>(
|
| 311 |
+
CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>));
|
| 312 |
+
template <>
|
| 313 |
+
void bsrsm2_solve<c10::complex<double>>(
|
| 314 |
+
CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>));
|
| 315 |
+
|
| 316 |
+
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
|
| 317 |
+
|
| 318 |
+
} // namespace at::cuda::sparse
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDATensorMethods.cuh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/util/Half.h>
|
| 5 |
+
|
| 6 |
+
#include <cuda.h>
|
| 7 |
+
#include <cuda_runtime.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
template <>
|
| 12 |
+
inline __half* Tensor::data() const {
|
| 13 |
+
return reinterpret_cast<__half*>(data<Half>());
|
| 14 |
+
}
|
| 15 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <limits.h>
|
| 5 |
+
#include <math.h>
|
| 6 |
+
#include <float.h>
|
| 7 |
+
|
| 8 |
+
// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
|
| 9 |
+
// types. This header is very specific to ROCm HIP and may be removed in the future.
|
| 10 |
+
// This header is derived from the legacy THCNumerics.cuh.
|
| 11 |
+
|
| 12 |
+
// The lower_bound and upper_bound constants are same as lowest and max for
|
| 13 |
+
// integral types, but are -inf and +inf for floating point types. They are
|
| 14 |
+
// useful in implementing min, max, etc.
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
template <typename T>
|
| 19 |
+
struct numeric_limits {
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
// WARNING: the following at::numeric_limits definitions are there only to support
|
| 23 |
+
// HIP compilation for the moment. Use std::numeric_limits if you are not
|
| 24 |
+
// compiling for ROCm.
|
| 25 |
+
// from @colesbury: "The functions on numeric_limits aren't marked with
|
| 26 |
+
// __device__ which is why they don't work with ROCm. CUDA allows them
|
| 27 |
+
// because they're constexpr."
|
| 28 |
+
|
| 29 |
+
namespace {
|
| 30 |
+
// ROCm doesn't like INFINITY too.
|
| 31 |
+
constexpr double inf = INFINITY;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
template <>
|
| 35 |
+
struct numeric_limits<bool> {
|
| 36 |
+
static inline __host__ __device__ bool lowest() { return false; }
|
| 37 |
+
static inline __host__ __device__ bool max() { return true; }
|
| 38 |
+
static inline __host__ __device__ bool lower_bound() { return false; }
|
| 39 |
+
static inline __host__ __device__ bool upper_bound() { return true; }
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
template <>
|
| 43 |
+
struct numeric_limits<uint8_t> {
|
| 44 |
+
static inline __host__ __device__ uint8_t lowest() { return 0; }
|
| 45 |
+
static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
|
| 46 |
+
static inline __host__ __device__ uint8_t lower_bound() { return 0; }
|
| 47 |
+
static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
template <>
|
| 51 |
+
struct numeric_limits<int8_t> {
|
| 52 |
+
static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
|
| 53 |
+
static inline __host__ __device__ int8_t max() { return INT8_MAX; }
|
| 54 |
+
static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
|
| 55 |
+
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template <>
|
| 59 |
+
struct numeric_limits<int16_t> {
|
| 60 |
+
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
|
| 61 |
+
static inline __host__ __device__ int16_t max() { return INT16_MAX; }
|
| 62 |
+
static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
|
| 63 |
+
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
template <>
|
| 67 |
+
struct numeric_limits<int32_t> {
|
| 68 |
+
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
|
| 69 |
+
static inline __host__ __device__ int32_t max() { return INT32_MAX; }
|
| 70 |
+
static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
|
| 71 |
+
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
template <>
|
| 75 |
+
struct numeric_limits<int64_t> {
|
| 76 |
+
#ifdef _MSC_VER
|
| 77 |
+
static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
|
| 78 |
+
static inline __host__ __device__ int64_t max() { return _I64_MAX; }
|
| 79 |
+
static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
|
| 80 |
+
static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
|
| 81 |
+
#else
|
| 82 |
+
static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
|
| 83 |
+
static inline __host__ __device__ int64_t max() { return INT64_MAX; }
|
| 84 |
+
static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
|
| 85 |
+
static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
|
| 86 |
+
#endif
|
| 87 |
+
};
|
| 88 |
+
|
| 89 |
+
template <>
|
| 90 |
+
struct numeric_limits<at::Half> {
|
| 91 |
+
static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
|
| 92 |
+
static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
|
| 93 |
+
static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
|
| 94 |
+
static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
template <>
|
| 98 |
+
struct numeric_limits<at::BFloat16> {
|
| 99 |
+
static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
|
| 100 |
+
static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
|
| 101 |
+
static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
|
| 102 |
+
static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
template <>
|
| 106 |
+
struct numeric_limits<float> {
|
| 107 |
+
static inline __host__ __device__ float lowest() { return -FLT_MAX; }
|
| 108 |
+
static inline __host__ __device__ float max() { return FLT_MAX; }
|
| 109 |
+
static inline __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
|
| 110 |
+
static inline __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
template <>
|
| 114 |
+
struct numeric_limits<double> {
|
| 115 |
+
static inline __host__ __device__ double lowest() { return -DBL_MAX; }
|
| 116 |
+
static inline __host__ __device__ double max() { return DBL_MAX; }
|
| 117 |
+
static inline __host__ __device__ double lower_bound() { return -inf; }
|
| 118 |
+
static inline __host__ __device__ double upper_bound() { return inf; }
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/macros/Export.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda {
|
| 6 |
+
|
| 7 |
+
// enqueues a kernel that spins for the specified number of cycles
|
| 8 |
+
TORCH_CUDA_CU_API void sleep(int64_t cycles);
|
| 9 |
+
|
| 10 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 4 |
+
|
| 5 |
+
#include <ATen/Generator.h>
|
| 6 |
+
#include <c10/util/Optional.h>
|
| 7 |
+
|
| 8 |
+
// TODO: No need to have this whole header, we can just put it all in
|
| 9 |
+
// the cpp file
|
| 10 |
+
|
| 11 |
+
namespace at::cuda::detail {
|
| 12 |
+
|
| 13 |
+
// Set the callback to initialize Magma, which is set by
|
| 14 |
+
// torch_cuda_cu. This indirection is required so magma_init is called
|
| 15 |
+
// in the same library where Magma will be used.
|
| 16 |
+
TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
// The real implementation of CUDAHooksInterface
|
| 20 |
+
struct CUDAHooks : public at::CUDAHooksInterface {
|
| 21 |
+
CUDAHooks(at::CUDAHooksArgs) {}
|
| 22 |
+
void initCUDA() const override;
|
| 23 |
+
Device getDeviceFromPtr(void* data) const override;
|
| 24 |
+
bool isPinnedPtr(const void* data) const override;
|
| 25 |
+
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
|
| 26 |
+
bool hasCUDA() const override;
|
| 27 |
+
bool hasMAGMA() const override;
|
| 28 |
+
bool hasCuDNN() const override;
|
| 29 |
+
bool hasCuSOLVER() const override;
|
| 30 |
+
bool hasROCM() const override;
|
| 31 |
+
const at::cuda::NVRTC& nvrtc() const override;
|
| 32 |
+
DeviceIndex current_device() const override;
|
| 33 |
+
bool hasPrimaryContext(DeviceIndex device_index) const override;
|
| 34 |
+
Allocator* getCUDADeviceAllocator() const override;
|
| 35 |
+
Allocator* getPinnedMemoryAllocator() const override;
|
| 36 |
+
bool compiledWithCuDNN() const override;
|
| 37 |
+
bool compiledWithMIOpen() const override;
|
| 38 |
+
bool supportsDilatedConvolutionWithCuDNN() const override;
|
| 39 |
+
bool supportsDepthwiseConvolutionWithCuDNN() const override;
|
| 40 |
+
bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
|
| 41 |
+
bool hasCUDART() const override;
|
| 42 |
+
long versionCUDART() const override;
|
| 43 |
+
long versionCuDNN() const override;
|
| 44 |
+
std::string showConfig() const override;
|
| 45 |
+
double batchnormMinEpsilonCuDNN() const override;
|
| 46 |
+
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
|
| 47 |
+
void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
|
| 48 |
+
int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
|
| 49 |
+
void cuFFTClearPlanCache(DeviceIndex device_index) const override;
|
| 50 |
+
int getNumGPUs() const override;
|
| 51 |
+
void deviceSynchronize(DeviceIndex device_index) const override;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
} // at::cuda::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
|
| 2 |
+
// These handles are tied to device, and these libraries requires/recommends not to
|
| 3 |
+
// share handles across host threads.
|
| 4 |
+
//
|
| 5 |
+
// These libraries recommend using one handle per host thread. We may not want to do
|
| 6 |
+
// this because threads are relatively light-weight, but creating and destroying
|
| 7 |
+
// handles is expensive (destroying the handle causes synchronizations). DataParallel,
|
| 8 |
+
// for example, creates new threads for each forward pass.
|
| 9 |
+
//
|
| 10 |
+
// This file implements a handle pool mechanism. The handle pool returns handles on
|
| 11 |
+
// demand as threads request them. If all existing handles in the pool are in use,
|
| 12 |
+
// it creates a new one. As threads terminate, they release handles back into the pool.
|
| 13 |
+
// In this way, the handle pool never creates more handles than the high-water mark of
|
| 14 |
+
// active threads, so it's efficient with DataParallel.
|
| 15 |
+
|
| 16 |
+
#pragma once
|
| 17 |
+
|
| 18 |
+
#include <unordered_map>
|
| 19 |
+
#include <vector>
|
| 20 |
+
#include <utility>
|
| 21 |
+
#include <mutex>
|
| 22 |
+
#include <memory>
|
| 23 |
+
|
| 24 |
+
#include <c10/util/Exception.h>
|
| 25 |
+
|
| 26 |
+
namespace at::cuda { namespace {
|
| 27 |
+
|
| 28 |
+
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
|
| 29 |
+
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
|
| 30 |
+
|
| 31 |
+
struct Handle {
|
| 32 |
+
Handle_t handle;
|
| 33 |
+
Handle(bool create = false) : handle(nullptr)
|
| 34 |
+
{
|
| 35 |
+
if(create) Create(&handle);
|
| 36 |
+
}
|
| 37 |
+
// std::vector.emplace() and push_back() may route through temporaries and call
|
| 38 |
+
// copy/move constructors along the way. If this is the case, we don't want
|
| 39 |
+
// the destructors of temporaries to call cudnnDestroy on the handle.
|
| 40 |
+
// We can achieve safety (for the narrow case of stashing within std::vectors)
|
| 41 |
+
// by making Handle moveable but not copyable, and transferring handle ownership
|
| 42 |
+
// to the latest constructed object. This is not a substitute for full-blown
|
| 43 |
+
// reference counting, but reference counting may be overkill here.
|
| 44 |
+
// Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
|
| 45 |
+
// unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
|
| 46 |
+
Handle(const Handle& rhs) = delete;
|
| 47 |
+
// Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
|
| 48 |
+
Handle(Handle&& rhs) : Handle() { std::swap(handle, rhs.handle); }
|
| 49 |
+
// operator= takes argument by value
|
| 50 |
+
Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
|
| 51 |
+
~Handle() {
|
| 52 |
+
if(handle) Destroy(handle);
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
std::mutex mutex;
|
| 57 |
+
|
| 58 |
+
// Handles are lazily created as different threads request them,
|
| 59 |
+
// but are never destroyed until the end of the process.
|
| 60 |
+
// The maximum number of handles this process will create for each device is equal
|
| 61 |
+
// to the high-water mark of the number of concurrently active threads that request
|
| 62 |
+
// handles for that device.
|
| 63 |
+
// When threads terminate, they release their handles back into the pool for reuse.
|
| 64 |
+
// Otherwise, new handles would be created every time new threads were spawned,
|
| 65 |
+
// resulting in poor performance for Python modules that repeatedly or frequently
|
| 66 |
+
// spawned new sets of threads (like DataParallel, which creates a new set of threads
|
| 67 |
+
// for each forward pass).
|
| 68 |
+
//
|
| 69 |
+
// To prevent potential deadlocks, we explicitly choose not to cap the number
|
| 70 |
+
// of handles that are created per device.
|
| 71 |
+
// Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
|
| 72 |
+
// only 4 can make forward progress at any time. The other 4 will not release their
|
| 73 |
+
// handles until they exit, so the fifth cannot make progress until then. This is
|
| 74 |
+
// not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
|
| 75 |
+
// intermediate point (ie, before any of them have exited). We have no way to anticipate
|
| 76 |
+
// or enforce that user threads will not attempt such intermediate synchronization.
|
| 77 |
+
// The only way to ensure safety is to avoid imposing a cap on the number of handles.
|
| 78 |
+
std::unordered_map<int, std::vector<Handle>> created_handles;
|
| 79 |
+
std::unordered_map<int, std::vector<Handle_t>> available_handles;
|
| 80 |
+
|
| 81 |
+
// PoolWindow lazily creates and caches the handles that a particular thread is using,
|
| 82 |
+
// so in the common case handle access doesn't incur either handle creation or a mutex lock.
|
| 83 |
+
class PoolWindow
|
| 84 |
+
{
|
| 85 |
+
public:
|
| 86 |
+
PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
|
| 87 |
+
~PoolWindow(){ release(); }
|
| 88 |
+
|
| 89 |
+
Handle_t reserve(int device)
|
| 90 |
+
{
|
| 91 |
+
// If this thread already has a handle for this device, return it
|
| 92 |
+
if(my_handles.find(device) != my_handles.end())
|
| 93 |
+
return my_handles[device];
|
| 94 |
+
|
| 95 |
+
// otherwise, either grab a handle from the pool if one is available,
|
| 96 |
+
// or if not, create a new one.
|
| 97 |
+
auto parent = weak_parent.lock();
|
| 98 |
+
TORCH_CHECK(parent, "Cannot create handle during program termination");
|
| 99 |
+
std::lock_guard<std::mutex> guard(parent->mutex);
|
| 100 |
+
|
| 101 |
+
if(parent->available_handles[device].size() > 0)
|
| 102 |
+
{
|
| 103 |
+
my_handles[device] = parent->available_handles[device].back();
|
| 104 |
+
parent->available_handles[device].pop_back();
|
| 105 |
+
}
|
| 106 |
+
else
|
| 107 |
+
{
|
| 108 |
+
// In local testing, I do observe that emplace_back sometimes routes through temporaries
|
| 109 |
+
// that incur move-constructor and destructor calls. See comments in Handle above.
|
| 110 |
+
parent->created_handles[device].emplace_back(true /*create*/);
|
| 111 |
+
my_handles[device] = parent->created_handles[device].back().handle;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
return my_handles[device];
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
private:
|
| 118 |
+
// Stores the per-device handles currently owned by this thread
|
| 119 |
+
std::unordered_map<int, Handle_t> my_handles;
|
| 120 |
+
|
| 121 |
+
std::weak_ptr<DeviceThreadHandlePool> weak_parent;
|
| 122 |
+
|
| 123 |
+
// Called by the destructor. Releases this thread's handles back into the pool.
|
| 124 |
+
void release() {
|
| 125 |
+
if(my_handles.size() > 0) {
|
| 126 |
+
auto parent = weak_parent.lock();
|
| 127 |
+
if (!parent) {
|
| 128 |
+
// If this thread exits after atexit handlers have completed, the
|
| 129 |
+
// cuda context itself may be invalid, so we must leak the handles.
|
| 130 |
+
return;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
std::lock_guard<std::mutex> guard(parent->mutex);
|
| 134 |
+
for(auto d_h : my_handles)
|
| 135 |
+
parent->available_handles[d_h.first].push_back(d_h.second);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
// Warning:
|
| 141 |
+
// If you want to change this function, be aware that this function will be called
|
| 142 |
+
// by multiple threads and there is no mutex guarding the call of this function, so
|
| 143 |
+
// make sure your implementation is thread-safe.
|
| 144 |
+
PoolWindow *newPoolWindow() {
|
| 145 |
+
// The returned pointer will be owned by a thread local variable
|
| 146 |
+
// so that different threads does not share the same PoolWindow.
|
| 147 |
+
return new PoolWindow(this->shared_from_this());
|
| 148 |
+
}
|
| 149 |
+
};
|
| 150 |
+
|
| 151 |
+
}} // namespace at::cuda::detail::<anonymous>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IndexUtils.cuh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
#include <ATen/cuda/detail/TensorInfo.cuh>
|
| 5 |
+
#include <ATen/native/CanUse32BitIndexMath.h>
|
| 6 |
+
|
| 7 |
+
namespace at::cuda::detail {
|
| 8 |
+
|
| 9 |
+
TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
|
| 10 |
+
using at::native::canUse32BitIndexMath;
|
| 11 |
+
|
| 12 |
+
template <typename scalar, typename IndexType>
|
| 13 |
+
TensorInfo<scalar, IndexType>
|
| 14 |
+
getTensorInfo(const at::TensorBase &t) {
|
| 15 |
+
IndexType sz[MAX_TENSORINFO_DIMS];
|
| 16 |
+
IndexType st[MAX_TENSORINFO_DIMS];
|
| 17 |
+
|
| 18 |
+
int dims = t.dim();
|
| 19 |
+
for (int i = 0; i < dims; ++i) {
|
| 20 |
+
sz[i] = t.size(i);
|
| 21 |
+
st[i] = t.stride(i);
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
scalar* data_ptr = nullptr;
|
| 25 |
+
|
| 26 |
+
if constexpr (std::is_const<scalar>::value) {
|
| 27 |
+
data_ptr = t.const_data_ptr<scalar>();
|
| 28 |
+
} else {
|
| 29 |
+
data_ptr = t.mutable_data_ptr<scalar>();
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
return TensorInfo<scalar, IndexType>(
|
| 33 |
+
data_ptr, dims, sz, st);
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
} // namespace at::cuda::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <assert.h>
|
| 4 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
| 5 |
+
#include <cuda_runtime.h>
|
| 6 |
+
#endif
|
| 7 |
+
|
| 8 |
+
namespace at::cuda::detail {
|
| 9 |
+
|
| 10 |
+
// A utility class to implement integer division by multiplication, given a fixed
|
| 11 |
+
// divisor.
|
| 12 |
+
//
|
| 13 |
+
// WARNING: The fast divider algorithm is only implemented for unsigned int;
|
| 14 |
+
// otherwise we default to plain integer division. For unsigned int,
|
| 15 |
+
// we further assume that the dividend is at most INT32_MAX. Thus,
|
| 16 |
+
// IntDivider must NOT be used for general integer division.
|
| 17 |
+
//
|
| 18 |
+
// This reduced range is enough for our purpose, and it allows us to
|
| 19 |
+
// slightly simplify the computation.
|
| 20 |
+
//
|
| 21 |
+
// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
|
| 22 |
+
//
|
| 23 |
+
// For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
|
| 24 |
+
// <= m < 2^(N+1)) and shift s such that:
|
| 25 |
+
//
|
| 26 |
+
// \floor(n / d) = \floor((m * n) / 2^(N+s)).
|
| 27 |
+
//
|
| 28 |
+
// Given such m and s, the integer division can be then implemented as:
|
| 29 |
+
//
|
| 30 |
+
// let m' = m - 2^N // 0 <= m' < 2^N
|
| 31 |
+
//
|
| 32 |
+
// fast_integer_division(n):
|
| 33 |
+
// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
|
| 34 |
+
// // integer. Then take the higher N bits.
|
| 35 |
+
// t = (m' * n) >> N
|
| 36 |
+
//
|
| 37 |
+
// // Here we use the fact that n is less than 2^(N-1): otherwise the value
|
| 38 |
+
// // of (t + n) may not fit in an N-bit integer.
|
| 39 |
+
// return (t + n) >> s
|
| 40 |
+
//
|
| 41 |
+
// Finding such a magic number is surprisingly easy:
|
| 42 |
+
//
|
| 43 |
+
// s = \ceil(\log_2 d)
|
| 44 |
+
// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
|
| 45 |
+
//
|
| 46 |
+
// See also:
|
| 47 |
+
// - Division by Invariant Integers Using Multiplication,
|
| 48 |
+
// Torbjörn Granlund and Peter L. Montgomery, 1994.
|
| 49 |
+
//
|
| 50 |
+
// - http://www.hackersdelight.org/magic.htm
|
| 51 |
+
//
|
| 52 |
+
// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
|
| 53 |
+
|
| 54 |
+
// Result of div/mod operation stored together.
|
| 55 |
+
template <typename Value>
|
| 56 |
+
struct DivMod {
|
| 57 |
+
Value div, mod;
|
| 58 |
+
|
| 59 |
+
C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
// Base case: we only have an implementation for uint32_t for now. For
|
| 63 |
+
// everything else, we use plain division.
|
| 64 |
+
template <typename Value>
|
| 65 |
+
struct IntDivider {
|
| 66 |
+
IntDivider() = default;
|
| 67 |
+
IntDivider(Value d) : divisor(d) { }
|
| 68 |
+
|
| 69 |
+
C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
|
| 70 |
+
C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
|
| 71 |
+
C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
|
| 72 |
+
return DivMod<Value>(n / divisor, n % divisor);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
Value divisor;
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
// Implement fast integer division.
//
// Division by a runtime-invariant divisor is replaced by a multiply with a
// precomputed 32-bit "magic number" followed by a shift, avoiding the very
// slow hardware divide instruction on the GPU.
template <>
struct IntDivider<unsigned int> {
  static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");

  IntDivider() = default;

  // Precompute the magic multiplier `m1` and the shift amount for divisor
  // `d`. Host-side only work: uses a 64-bit divide once at construction.
  IntDivider(unsigned int d) : divisor(d) {
    assert(divisor >= 1 && divisor <= INT32_MAX);

    // Find the smallest `shift` such that 2^shift >= divisor.
    // TODO: gcc/clang has __builtin_clz() but it's not portable.
    for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;

    uint64_t one = 1;
    uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
    m1 = magic;
    assert(m1 > 0 && m1 == magic);  // m1 must fit in 32 bits.
  }

  // Quotient n / divisor, computed as the high half of (n * m1) plus n,
  // shifted right. Branch-free on both host and device.
  C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
    // 'm1'.
    unsigned int t = __umulhi(n, m1);
    return (t + n) >> shift;
#else
    // Using uint64_t so that the addition does not overflow.
    uint64_t t = ((uint64_t) n * m1) >> 32;
    return (t + n) >> shift;
#endif
  }

  // Remainder derived from the fast quotient with one multiply/subtract.
  C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
    return n - div(n) * divisor;
  }

  // Quotient and remainder together, paying for the division only once.
  C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
    unsigned int q = div(n);
    return DivMod<unsigned int>(q, n - q * divisor);
  }

  unsigned int divisor;  // d above.
  unsigned int m1;  // Magic number: m' above.
  unsigned int shift;  // Shift amounts.
};
|
| 123 |
+
|
| 124 |
+
} // namespace at::cuda::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/OffsetCalculator.cuh
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <array>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
#include <type_traits>
|
| 6 |
+
#include <c10/macros/Macros.h>
|
| 7 |
+
#include <ATen/core/Array.h>
|
| 8 |
+
#include <ATen/native/TensorIterator.h>
|
| 9 |
+
#include <ATen/cuda/detail/IntegerDivider.cuh>
|
| 10 |
+
|
| 11 |
+
// If element_sizes is nullptr, then the strides will be in bytes, otherwise
|
| 12 |
+
// the strides will be in # of elements.
|
| 13 |
+
// Operands that share the same shape, but may have different strides.
|
| 14 |
+
// OffsetCalculator iterates the tensor in a column-major order
|
| 15 |
+
|
| 16 |
+
#if defined(USE_ROCM)
|
| 17 |
+
constexpr int MAX_DIMS = 16;
|
| 18 |
+
#else
|
| 19 |
+
constexpr int MAX_DIMS = 25;
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
// Maps a linear (flattened) index to one offset per operand for NARGS
// operands that share a shape but may have different strides. Iteration is
// column-major: dimension 0 varies fastest.
template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
struct OffsetCalculator {
  // We allow having negative strides to implement some operations like torch.flip
  using stride_t = std::conditional_t<signed_strides,
                                      std::make_signed_t<index_t>,
                                      index_t>;
  // The offset for each argument. Wrapper around fixed-size array.
  // On CUDA, zero sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size 1 offset to avoid compiler failure.
  // This size 1 offset is just a placeholder, and we will not use it.
  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;

  // if element_sizes is nullptr, then the strides will be in bytes, otherwise
  // the strides will be in # of elements.
  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
    for (int i=0; i < dims; i++){
      // Sizes are wrapped in IntDivider so the kernel can divide/mod by them
      // without hardware division.
      sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
      for (int arg = 0; arg < NARGS; arg++) {
        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
        strides_[i][arg] = strides[arg][i] / element_size;
      }
    }
  }

  // Decompose `linear_idx` one dimension at a time (dimension 0 first) and
  // accumulate each operand's offset from the per-dimension strides.
  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    offset_type offsets;
#pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = 0;
    }

    // The loop bound is the compile-time constant MAX_DIMS so that
    // `#pragma unroll` can apply; the runtime `dims` check exits early.
#pragma unroll
    for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
        break;
      }
      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

#pragma unroll
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] += divmod.mod * strides_[dim][arg];
      }

    }
    return offsets;
  }

  int dims;  // number of valid dimensions (<= MAX_DIMS)
  at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];  // per-dim sizes as fast dividers
  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];  // per-dim, per-operand strides
};
|
| 75 |
+
|
| 76 |
+
template <int NARGS, typename index_t = uint32_t>
|
| 77 |
+
struct TrivialOffsetCalculator {
|
| 78 |
+
// The offset for each argument. Wrapper around fixed-size array.
|
| 79 |
+
// The offsets are in # of elements, not in bytes.
|
| 80 |
+
// On CUDA, zero sized array is not allowed, so when we are handling nullary
|
| 81 |
+
// operators, we need to create a size 1 offset to avoid compiler failure.
|
| 82 |
+
// This size 1 offset is just a placeholder, and we will not use it.
|
| 83 |
+
using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;
|
| 84 |
+
|
| 85 |
+
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
|
| 86 |
+
offset_type offsets;
|
| 87 |
+
#pragma unroll
|
| 88 |
+
for (int arg = 0; arg < NARGS; arg++) {
|
| 89 |
+
offsets[arg] = linear_idx;
|
| 90 |
+
}
|
| 91 |
+
return offsets;
|
| 92 |
+
}
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
// Make an OffsetCalculator with byte offsets
|
| 96 |
+
template<int N, bool signed_strides = false>
|
| 97 |
+
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
|
| 98 |
+
TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
|
| 99 |
+
std::array<const int64_t*, N> strides;
|
| 100 |
+
for (int i = 0; i < N; i++) {
|
| 101 |
+
strides[i] = iter.strides(i).data();
|
| 102 |
+
}
|
| 103 |
+
return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// Make an OffsetCalculator with element offsets
|
| 107 |
+
template<int N, bool signed_strides = false>
|
| 108 |
+
static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
|
| 109 |
+
const at::TensorIteratorBase& iter) {
|
| 110 |
+
TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
|
| 111 |
+
std::array<const int64_t*, N> strides;
|
| 112 |
+
std::array<int64_t, N> element_sizes;
|
| 113 |
+
for (int i = 0; i < N; i++) {
|
| 114 |
+
strides[i] = iter.strides(i).data();
|
| 115 |
+
element_sizes[i] = iter.element_size(i);
|
| 116 |
+
}
|
| 117 |
+
return OffsetCalculator<N, uint32_t, signed_strides>(
|
| 118 |
+
iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
|
| 119 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
|
| 2 |
+
// Eager mode clients should not include this file directly, instead,
|
| 3 |
+
// they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
|
| 4 |
+
|
| 5 |
+
// Stores RNG state values. Passed as a kernel argument.
|
| 6 |
+
// See Note [CUDA Graph-safe RNG states].
|
| 7 |
+
//
|
| 8 |
+
// The raw definition lives in its own file so jit codegen can easily copy it.
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
// Stores Philox RNG state values, passed by value as a kernel argument.
// Which Payload member is active depends on `captured_`: the value
// constructor sets `.val` (captured_ stays false); the pointer constructor
// sets `.ptr` and flips captured_ to true.
struct PhiloxCudaState {
  PhiloxCudaState() = default;
  // Called if graph capture is not underway
  PhiloxCudaState(uint64_t seed,
                  uint64_t offset) {
    seed_.val = seed;
    offset_.val = offset;
  }
  // Called if graph capture is underway
  PhiloxCudaState(int64_t* seed,
                  int64_t* offset_extragraph,
                  uint32_t offset_intragraph) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
    offset_intragraph_ = offset_intragraph;
    captured_ = true;
  }

  // Public members, directly accessible by at::cuda::philox::unpack.
  // If we made them private with getters/setters, the getters/setters
  // would have to be __device__, and we can't declare __device__ in ATen.
  union Payload {
    uint64_t val;   // immediate value (eager mode, captured_ == false)
    int64_t* ptr;   // indirection for graph capture (captured_ == true)
  };

  Payload seed_;
  Payload offset_;
  uint32_t offset_intragraph_ = 0;  // set only by the graph-capture constructor
  bool captured_ = false;           // selects which Payload member is active
};
|
| 42 |
+
|
| 43 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/CollapseDims.h>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda::detail {
|
| 6 |
+
|
| 7 |
+
#define MAX_TENSORINFO_DIMS 25
|
| 8 |
+
|
| 9 |
+
// CUDA kernel argument that defines tensor layout
// Passed by value to kernels; holds a raw data pointer plus fixed-size
// size/stride arrays of which only the first `dims` entries are valid.
template <typename T, typename IndexType>
struct TensorInfo {
  TensorInfo();
  TensorInfo(T* p,
             int dim,
             IndexType sz[MAX_TENSORINFO_DIMS],
             IndexType st[MAX_TENSORINFO_DIMS]);

  // Set the size of the given dimension to 1, as if it were a
  // reduction dim (allows you to calculate offsets of the reduction
  // slice)
  void reduceDim(int dim);

  // See note on [collapse dims].
  int collapseDims(const int excludeDim = -1);

  // Contiguous tensors of more than one dimension are collapsed down
  // to one tensor
  __host__ __device__ inline bool isContiguous() const {
    return (dims == 1 && strides[0] == 1);
  }

  T* data;                                 // raw element pointer (not owned)
  IndexType sizes[MAX_TENSORINFO_DIMS];    // only indices [0, dims) are valid
  IndexType strides[MAX_TENSORINFO_DIMS];  // only indices [0, dims) are valid
  int dims;                                // number of valid dimensions
};
|
| 37 |
+
|
| 38 |
+
template <typename T, typename IndexType>
|
| 39 |
+
TensorInfo<T, IndexType>::TensorInfo() {
|
| 40 |
+
data = nullptr;
|
| 41 |
+
dims = 0;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// Construct from a raw pointer plus `dim` sizes and strides, copying the
// first `dim` entries of each array into the fixed-size members.
template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo(T* p,
                                     int dim,
                                     IndexType sz[MAX_TENSORINFO_DIMS],
                                     IndexType st[MAX_TENSORINFO_DIMS]) {
  data = p;
  dims = dim;
  // NOTE(review): the check uses `<` while the message says "more than 25";
  // dims == MAX_TENSORINFO_DIMS (25) is rejected even though the arrays could
  // hold it — confirm the off-by-one is intentional.
  TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");

  for (int i = 0; i < dim; ++i) {
    sizes[i] = sz[i];
    strides[i] = st[i];
  }
}
|
| 58 |
+
|
| 59 |
+
template <typename T, typename IndexType>
|
| 60 |
+
void
|
| 61 |
+
TensorInfo<T, IndexType>::reduceDim(int dim) {
|
| 62 |
+
TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
|
| 63 |
+
sizes[dim] = 1;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
template <typename T, typename IndexType>
|
| 67 |
+
int
|
| 68 |
+
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
|
| 69 |
+
auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
|
| 70 |
+
dims = std::get<1>(result);
|
| 71 |
+
return std::get<0>(result);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Translate a linear index for the apply to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
struct IndexToOffset {
  static __host__ __device__ IndexType get(
    IndexType linearId,
    const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    // Uses static dims
    // Peel dimensions from the last (fastest-varying here) down to 1 via
    // mod/div; the compile-time bound lets the compiler fully unroll.
    for (int i = Dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    // Dimension 0 takes whatever remains — no modulo needed.
    return offset + linearId * info.strides[0];
  }
};
|
| 95 |
+
|
| 96 |
+
// Uses dynamic (runtime) instead of static (compiletime) dims
// Same algorithm as the primary template, but the loop bound is info.dims,
// so the compiler cannot unroll it.
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -1> {
  static inline __host__ __device__ IndexType get(
    IndexType linearId,
    const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    for (int i = info.dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    // Dimension 0 takes the remaining quotient without a modulo.
    return offset + linearId * info.strides[0];
  }
};
|
| 115 |
+
|
| 116 |
+
} // namespace at::cuda::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/jiterator.h
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/jit_macros.h>
|
| 3 |
+
|
| 4 |
+
#if AT_USE_JITERATOR()

#include <c10/macros/Export.h>
#include <c10/util/SmallVector.h>
#include <ATen/core/Tensor.h>

#include <string>
#include <vector>

namespace at::cuda {

// Jiterator entry point: compiles `code_string` (with entry symbol
// `kernel_name`) and launches it over `tensors`, returning `num_outputs`
// result tensors. `extra_args` are forwarded as scalar kernel arguments;
// `return_by_ref` selects the return convention.
TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const int num_outputs,
    const c10::SmallVector<at::Tensor>& tensors,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref);

} // namespace at::cuda

#else

namespace at::cuda {

// Stub for builds without jiterator support: unconditionally throws.
// NOTE(review): this is a definition (not just a declaration) in a header and
// is not marked `inline` — confirm this cannot be included by more than one
// TU in the same binary, or that TORCH_CUDA_CPP_API handles it.
TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
    const std::string& code_string,
    const std::string& kernel_name,
    const int num_outputs,
    const c10::SmallVector<at::Tensor>& tensors,
    const c10::SmallVector<at::Scalar>& extra_args,
    bool return_by_ref) {
  TORCH_CHECK(false, "Jiterator is not supported");
}
} // namespace at::cuda

#endif // AT_USE_JITERATOR()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <string>
|
| 13 |
+
|
| 14 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 15 |
+
#include <ATen/cuda/Exceptions.h>
|
| 16 |
+
#include <c10/util/StringUtil.h>
|
| 17 |
+
|
| 18 |
+
namespace at::cuda::tunable {
|
| 19 |
+
|
| 20 |
+
// Transpose flag for a GEMM operand, mirroring the standard BLAS
// 'N' (not transposed) / 'T' (transposed) characters.
enum class BlasOp {
  N = 0,
  T = 1
};
|
| 24 |
+
|
| 25 |
+
// Render a BlasOp as its single-character BLAS transpose string.
// The trailing return is unreachable (TORCH_CHECK throws) but keeps the
// compiler's all-paths-return analysis happy.
inline std::string BlasOpToString(BlasOp op) {
  if (op == BlasOp::N) {
    return "N";
  }
  if (op == BlasOp::T) {
    return "T";
  }
  TORCH_CHECK(false, "unrecognized BlasOp");
  return "N";
}
|
| 35 |
+
|
| 36 |
+
// Tunable-op parameter pack for a single (non-batched) GEMM call:
// C = alpha * op(A) * op(B) + beta * C. Pointers are borrowed, except that
// DeepCopy allocates a fresh `c` which the caller must release via Delete().
template <typename T>
struct GemmParams : OpParams {
  // Cache key: transpose flags plus problem shape.
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }

  // Shallow-copies all fields, then gives the copy its own `c` buffer holding
  // a snapshot of the current output (async copy on the current stream).
  GemmParams* DeepCopy() const {
    GemmParams* copy = new GemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = m * n * sizeof(T);
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
  }

  // Compares this->c against other->c (as flat float tensors) over a grid of
  // tolerances from loose (1e-1) to tight (1e-5); records the last-iterated
  // passing pair. FAIL only if no tolerance pair passes at all.
  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
    // comparison done as 1D tensor
    at::Tensor ref = at::from_blob(c, {m*n}, options);
    at::Tensor oth = at::from_blob(other->c, {m*n}, options);
    at::Tensor ref_float = ref.to(at::kFloat);
    at::Tensor oth_float = oth.to(at::kFloat);
    std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    double last_succeed_atol = 1;
    double last_succeed_rtol = 1;
    for (auto& atol : atols) {
      for (auto& rtol : rtols) {
        if (at::allclose(ref_float, oth_float, rtol, atol)) {
          last_succeed_atol = atol;
          last_succeed_rtol = rtol;
        }
      }
    }
    if (last_succeed_atol == 1) {
      return FAIL;
    }
    else {
      TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
    }

    return OK;
  }

  char transa;              // 'n'/'N' or 't'/'T' for operand A
  char transb;              // 'n'/'N' or 't'/'T' for operand B
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;              // leading dimension of A
  const T* b;
  int64_t ldb;              // leading dimension of B
  at::opmath_type<T> beta;
  T* c;                     // output buffer; owned only by DeepCopy results
  int64_t ldc;              // leading dimension of C
};
|
| 102 |
+
|
| 103 |
+
// Tunable-op parameter pack for a strided-batched GEMM: `batch` independent
// GEMMs whose operands are laid out `stride_{a,b,c}` elements apart.
// Mirrors GemmParams; DeepCopy's result owns its `c` and needs Delete().
template <typename T>
struct GemmStridedBatchedParams : OpParams {
  // Cache key: transpose flags, problem shape, and batch count.
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
  }

  // Shallow-copies all fields, then snapshots the whole batched output
  // (batch * stride_c elements) into a freshly allocated `c`.
  GemmStridedBatchedParams* DeepCopy() const {
    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = batch * stride_c * sizeof(T);
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
  }

  // Same tolerance sweep as GemmParams::NumericalCheck, over the full
  // batched output viewed as one flat tensor.
  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
    // comparison done as 1D tensor
    at::Tensor ref = at::from_blob(c, {batch*stride_c}, options);
    at::Tensor oth = at::from_blob(other->c, {batch*stride_c}, options);
    at::Tensor ref_float = ref.to(at::kFloat);
    at::Tensor oth_float = oth.to(at::kFloat);
    std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
    double last_succeed_atol = 1;
    double last_succeed_rtol = 1;
    for (auto& atol : atols) {
      for (auto& rtol : rtols) {
        if (at::allclose(ref_float, oth_float, rtol, atol)) {
          last_succeed_atol = atol;
          last_succeed_rtol = rtol;
        }
      }
    }
    if (last_succeed_atol == 1) {
      return FAIL;
    }
    else {
      TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
    }

    return OK;
  }

  char transa;
  char transb;
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;
  int64_t stride_a;   // element distance between consecutive A matrices
  const T* b;
  int64_t ldb;
  int64_t stride_b;   // element distance between consecutive B matrices
  at::opmath_type<T> beta;
  T* c;
  int64_t ldc;
  int64_t stride_c;   // element distance between consecutive C matrices
  int64_t batch;      // number of GEMMs in the batch
};
|
| 173 |
+
|
| 174 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 7 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 8 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 9 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 10 |
+
#include <c10/util/StringUtil.h>
|
| 11 |
+
|
| 12 |
+
#include <hipblaslt/hipblaslt.h>
|
| 13 |
+
#include <hipblaslt/hipblaslt-ext.hpp>
|
| 14 |
+
|
| 15 |
+
#define TORCH_HIPBLASLT_CHECK(EXPR) \
|
| 16 |
+
do { \
|
| 17 |
+
hipblasStatus_t __err = EXPR; \
|
| 18 |
+
TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
|
| 19 |
+
"hipblaslt error: ", \
|
| 20 |
+
hipblasStatusToString(__err), \
|
| 21 |
+
" when calling `" #EXPR "`"); \
|
| 22 |
+
} while (0)
|
| 23 |
+
|
| 24 |
+
namespace at::cuda::tunable {
|
| 25 |
+
|
| 26 |
+
#ifdef HIPBLASLT_HAS_GETINDEXFROMALGO
|
| 27 |
+
#define GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
|
| 28 |
+
#else
|
| 29 |
+
// Fallback compiled only when hipblaslt_ext::getIndexFromAlgo is unavailable
// (see the surrounding #ifdef): reads the algorithm index from the first int
// of the opaque algo data blob; any negative value is reported as -1.
static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) {
  int* algo_ptr = (int*)algo.data;
  if(*algo_ptr < 0) {
    return -1;
  }
  return *algo_ptr;
}
|
| 36 |
+
#define GETINDEXFROMALGO(algo) getIndexFromAlgo(algo)
|
| 37 |
+
#endif
|
| 38 |
+
|
| 39 |
+
#ifdef HIPBLASLT_CUSTOM_COMPUTE_TYPE
|
| 40 |
+
#define COMPUTE_TYPE_32 HIPBLASLT_COMPUTE_F32
|
| 41 |
+
#else
|
| 42 |
+
#define COMPUTE_TYPE_32 HIPBLAS_COMPUTE_32F
|
| 43 |
+
#endif
|
| 44 |
+
|
| 45 |
+
#ifdef HIPBLASLT_CUSTOM_DATA_TYPE
|
| 46 |
+
|
| 47 |
+
// Map a C++ element type to the hipBLASLt-specific datatype enum. This branch
// is compiled when HIPBLASLT_CUSTOM_DATA_TYPE is defined (see surrounding
// #ifdef). The primary template is intentionally left undefined so that an
// unsupported type is a link/compile error.
template <typename T>
constexpr hipblasltDatatype_t HipBlasDataTypeFor();

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
  return HIPBLASLT_R_32F;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<Half>() {
  return HIPBLASLT_R_16F;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
  return HIPBLASLT_R_16B;
}

template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
  return HIPBLASLT_R_64F;
}
|
| 69 |
+
|
| 70 |
+
#define DATA_TYPE_R_32 HIPBLASLT_R_32F
|
| 71 |
+
|
| 72 |
+
#else
|
| 73 |
+
|
| 74 |
+
// Map a C++ element type to the plain hipBLAS datatype enum. This branch is
// compiled when HIPBLASLT_CUSTOM_DATA_TYPE is NOT defined (see surrounding
// #else). The primary template is intentionally left undefined so that an
// unsupported type fails to compile.
template <typename T>
constexpr hipblasDatatype_t HipBlasDataTypeFor();

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
  return HIPBLAS_R_32F;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<Half>() {
  return HIPBLAS_R_16F;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
  return HIPBLAS_R_16B;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
  return HIPBLAS_R_64F;
}
|
| 96 |
+
|
| 97 |
+
#ifdef HIPBLAS_V2
|
| 98 |
+
#define DATA_TYPE_R_32 HIP_R_32F
|
| 99 |
+
#else
|
| 100 |
+
#define DATA_TYPE_R_32 HIPBLAS_R_32F
|
| 101 |
+
#endif
|
| 102 |
+
|
| 103 |
+
#endif
|
| 104 |
+
|
| 105 |
+
// Overload family extracting batch/stride information from a params struct.
// The generic overloads are the fallback for non-batched GemmParams (batch of
// 1, placeholder stride of 1); the GemmStridedBatchedParams<T> overloads
// forward the real field values. Ordinary overload resolution on the pointer
// type selects the right one at each call site.
// NOTE(review): the stride fields are int64_t but these return int — confirm
// strides can never exceed INT_MAX here, or this silently truncates.
template <typename T, typename ParamsT>
int GetBatchFromParams(const ParamsT* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->batch;
}

template <typename T, typename ParamsT>
int GetStrideAFromParams(const ParamsT* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_a;
}

template <typename T, typename ParamsT>
int GetStrideBFromParams(const ParamsT* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_b;
}

template <typename T, typename ParamsT>
int GetStrideCFromParams(const ParamsT* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_c;
}
|
| 144 |
+
|
| 145 |
+
// Convert a BLAS transpose character (either case) to the hipBLAS operation
// enum; any other character is a hard error.
static hipblasOperation_t _hipblasOpFromChar(char op) {
  if (op == 'n' || op == 'N') {
    return HIPBLAS_OP_N;
  }
  if (op == 't' || op == 'T') {
    return HIPBLAS_OP_T;
  }
  if (op == 'c' || op == 'C') {
    return HIPBLAS_OP_C;
  }
  AT_ERROR(
      "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
|
| 160 |
+
|
| 161 |
+
// Inverse of _hipblasOpFromChar: render a hipBLAS operation enum as its
// canonical uppercase BLAS character; unknown enum values are a hard error.
static char _charFromhipblasOp(hipblasOperation_t op) {
  if (op == HIPBLAS_OP_N) {
    return 'N';
  }
  if (op == HIPBLAS_OP_T) {
    return 'T';
  }
  if (op == HIPBLAS_OP_C) {
    return 'C';
  }
  AT_ERROR(
      "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
}
|
| 173 |
+
|
| 174 |
+
// Convert a BlasOp layout tag to the corresponding hipBLAS transpose enum
// (anything other than N is treated as T — BlasOp has only those two values).
static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
  return (layout == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
}
|
| 180 |
+
|
| 181 |
+
static size_t GetHipblasltWorkspaceSize() {
|
| 182 |
+
static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
|
| 183 |
+
// 256MB is max workspace size allowed for hipblaslt
|
| 184 |
+
// hipblaslt-bench uses 32MB
|
| 185 |
+
// recommendation from hipblaslt author was 76MB
|
| 186 |
+
size_t workspace_size = 2*128*1024*1024; // default 256MB
|
| 187 |
+
if (env) {
|
| 188 |
+
try {
|
| 189 |
+
workspace_size = std::stoi(env);
|
| 190 |
+
} catch(std::invalid_argument const& e) {
|
| 191 |
+
TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
|
| 192 |
+
" using default workspace size of ", workspace_size, " bytes.");
|
| 193 |
+
} catch(std::out_of_range const& e) {
|
| 194 |
+
TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
|
| 195 |
+
" using default workspace size of ", workspace_size, " bytes.");
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
return workspace_size;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
| 202 |
+
class HipblasltGemmOp : public Callable<ParamsT> {
|
| 203 |
+
public:
|
| 204 |
+
HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
|
| 205 |
+
|
| 206 |
+
TuningStatus Call(const ParamsT* params) override {
|
| 207 |
+
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
| 208 |
+
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
| 209 |
+
auto in_out_datatype = HipBlasDataTypeFor<T>();
|
| 210 |
+
auto opa = _hipblasOpFromChar(params->transa);
|
| 211 |
+
auto opb = _hipblasOpFromChar(params->transb);
|
| 212 |
+
|
| 213 |
+
TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
|
| 214 |
+
|
| 215 |
+
float alpha = static_cast<float>(params->alpha);
|
| 216 |
+
float beta = static_cast<float>(params->beta);
|
| 217 |
+
|
| 218 |
+
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
|
| 219 |
+
hipblasLtMatmulDesc_t matmul;
|
| 220 |
+
if (opa == HIPBLAS_OP_N) {
|
| 221 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->m, params->k, params->lda));
|
| 222 |
+
}
|
| 223 |
+
else {
|
| 224 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->k, params->m, params->lda));
|
| 225 |
+
}
|
| 226 |
+
if (opb == HIPBLAS_OP_N) {
|
| 227 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->k, params->n, params->ldb));
|
| 228 |
+
}
|
| 229 |
+
else {
|
| 230 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->n, params->k, params->ldb));
|
| 231 |
+
}
|
| 232 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
|
| 233 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul, COMPUTE_TYPE_32, DATA_TYPE_R_32));
|
| 234 |
+
|
| 235 |
+
int batch = GetBatchFromParams<T>(params);
|
| 236 |
+
if (batch > 1) {
|
| 237 |
+
int64_t stride_a = GetStrideAFromParams<T>(params);
|
| 238 |
+
int64_t stride_b = GetStrideBFromParams<T>(params);
|
| 239 |
+
int64_t stride_c = GetStrideCFromParams<T>(params);
|
| 240 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 241 |
+
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 242 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 243 |
+
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
|
| 244 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 245 |
+
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 246 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 247 |
+
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
|
| 248 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 249 |
+
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 250 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 251 |
+
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
|
| 255 |
+
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &opa, sizeof(int32_t)));
|
| 256 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
|
| 257 |
+
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &opb, sizeof(int32_t)));
|
| 258 |
+
|
| 259 |
+
size_t workspace_size = GetHipblasltWorkspaceSize();
|
| 260 |
+
|
| 261 |
+
auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
|
| 262 |
+
|
| 263 |
+
size_t ret_workspace_size = 0;
|
| 264 |
+
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
|
| 265 |
+
matmul,
|
| 266 |
+
&alpha,
|
| 267 |
+
mat_a,
|
| 268 |
+
mat_b,
|
| 269 |
+
&beta,
|
| 270 |
+
mat_c,
|
| 271 |
+
mat_c,
|
| 272 |
+
algo_,
|
| 273 |
+
ret_workspace_size);
|
| 274 |
+
|
| 275 |
+
if (status == HIPBLAS_STATUS_SUCCESS) {
|
| 276 |
+
if (ret_workspace_size >= workspace_size) {
|
| 277 |
+
//TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large");
|
| 278 |
+
return FAIL;
|
| 279 |
+
}
|
| 280 |
+
}
|
| 281 |
+
else {
|
| 282 |
+
//TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported");
|
| 283 |
+
return FAIL;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
void* workspace_buffer = nullptr;
|
| 287 |
+
if (workspace_size > 0) {
|
| 288 |
+
workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
|
| 292 |
+
matmul,
|
| 293 |
+
&alpha,
|
| 294 |
+
params->a,
|
| 295 |
+
mat_a,
|
| 296 |
+
params->b,
|
| 297 |
+
mat_b,
|
| 298 |
+
&beta,
|
| 299 |
+
params->c,
|
| 300 |
+
mat_c,
|
| 301 |
+
params->c,
|
| 302 |
+
mat_c,
|
| 303 |
+
&algo_,
|
| 304 |
+
workspace_buffer,
|
| 305 |
+
workspace_size,
|
| 306 |
+
at::cuda::getCurrentCUDAStream()));
|
| 307 |
+
|
| 308 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
|
| 309 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
|
| 310 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
|
| 311 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
|
| 312 |
+
if (workspace_size > 0) {
|
| 313 |
+
c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
|
| 314 |
+
}
|
| 315 |
+
return OK;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
private:
|
| 319 |
+
hipblasLtMatmulAlgo_t algo_;
|
| 320 |
+
};
|
| 321 |
+
|
| 322 |
+
template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
| 323 |
+
auto GetHipBlasLtTypeStringAndOps() {
|
| 324 |
+
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
| 325 |
+
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
| 326 |
+
auto in_out_datatype = HipBlasDataTypeFor<T>();
|
| 327 |
+
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
|
| 328 |
+
|
| 329 |
+
hipblasLtHandle_t handle;
|
| 330 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
|
| 331 |
+
TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
|
| 332 |
+
hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
|
| 333 |
+
transa_outer,
|
| 334 |
+
transb_outer,
|
| 335 |
+
in_out_datatype,
|
| 336 |
+
in_out_datatype,
|
| 337 |
+
in_out_datatype,
|
| 338 |
+
in_out_datatype,
|
| 339 |
+
COMPUTE_TYPE_32,
|
| 340 |
+
heuristic_result));
|
| 341 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
|
| 342 |
+
|
| 343 |
+
// Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
|
| 344 |
+
std::sort(heuristic_result.begin(),
|
| 345 |
+
heuristic_result.end(),
|
| 346 |
+
[](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
|
| 347 |
+
return GETINDEXFROMALGO(a.algo) < GETINDEXFROMALGO(b.algo);
|
| 348 |
+
});
|
| 349 |
+
|
| 350 |
+
int returned_algo_count = heuristic_result.size();
|
| 351 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
|
| 352 |
+
for (int i = 0; i < returned_algo_count; i++) {
|
| 353 |
+
auto algo = heuristic_result[i].algo;
|
| 354 |
+
int algo_index = GETINDEXFROMALGO(algo);
|
| 355 |
+
auto callable = std::make_unique<HipblasltGemmOp<T, ALayout, BLayout, ParamsT>>(algo);
|
| 356 |
+
std::string type_string = c10::str(
|
| 357 |
+
"Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
|
| 358 |
+
ret.emplace_back(type_string, std::move(callable));
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
return ret;
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
// Plain (non-batched) GEMM specialization: enumerate hipBLASLt candidates
// operating on GemmParams<T>.
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmParams<T>>();
}
|
| 368 |
+
|
| 369 |
+
// Strided-batched GEMM specialization: enumerate hipBLASLt candidates
// operating on GemmStridedBatchedParams<T>.
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}
|
| 373 |
+
|
| 374 |
+
#undef TORCH_HIPBLASLT_CHECK
|
| 375 |
+
#undef GETINDEXFROMALGO
|
| 376 |
+
#undef COMPUTE_TYPE_32
|
| 377 |
+
#undef DATA_TYPE_R_32
|
| 378 |
+
|
| 379 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <cuda_runtime.h>
|
| 13 |
+
|
| 14 |
+
#include <ATen/cuda/tunable/Tunable.h>
|
| 15 |
+
|
| 16 |
+
namespace at::cuda::tunable {
|
| 17 |
+
|
| 18 |
+
// ITimer implementation backed by a pair of cudaEvent_t objects — presumably
// recorded on the current stream to time GPU work between Start() and End().
// (Definitions live in the corresponding .cpp; only the interface is visible
// here.)
class StreamTimer : public ITimer {
  public:
    StreamTimer();
    virtual ~StreamTimer();

    // Marks the beginning of the timed region (records start_).
    void Start() override;

    // Marks the end of the timed region (records end_).
    void End() override;

    // Elapsed time between Start() and End(). NOTE(review): units are not
    // visible from this declaration — CUDA event timing convention is
    // milliseconds; confirm against the implementation.
    float Duration() override;

  private:
    cudaEvent_t start_;  // event recorded by Start()
    cudaEvent_t end_;    // event recorded by End()
};
|
| 33 |
+
|
| 34 |
+
} // namespace at::cuda::tunable
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Device.h>
|
| 4 |
+
#include <c10/core/Layout.h>
|
| 5 |
+
#include <c10/core/MemoryFormat.h>
|
| 6 |
+
#include <c10/core/ScalarType.h>
|
| 7 |
+
#include <c10/util/Optional.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
class Tensor;
namespace native {
// Presumably reports whether at::native::to(...) with these arguments would
// return an alias/view of `self` rather than a fresh copy — implementation
// not visible here; confirm against TensorConversions.cpp.
bool to_will_alias(
    const Tensor& self,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    bool copy,
    c10::optional<c10::MemoryFormat> optional_memory_format);

// Overload set converting a tensor / optional tensor / tensor list to the
// meta device (definitions elsewhere).
Tensor to_meta(const Tensor& tensor);
c10::optional<Tensor> to_meta(const c10::optional<Tensor>& tensor);
std::vector<Tensor> to_meta(at::ITensorListRef t_list);
// Dense-to-sparse conversion guided by `mask`; layout/blocksize/dense_dim
// select the target sparse format (definition elsewhere).
Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt);

} // namespace native
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace at {
class Tensor;

namespace native {

// Kernel signature for the group-norm forward pass. Outputs are written to
// Y, with the per-group statistics written to mean and rstd.
using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    int64_t /* N */,
    int64_t /* C */,
    int64_t /* HxW */,
    int64_t /* group */,
    double /* eps */,
    Tensor& /* Y */,
    Tensor& /* mean */,
    Tensor& /* rstd */);

// Kernel signature for the group-norm backward pass. Gradients for the
// input and affine parameters are written to dX, dgamma, and dbeta.
using backward_fn = void (*)(
    const Tensor& /* dY */,
    const Tensor& /* X */,
    const Tensor& /* mean */,
    const Tensor& /* rstd */,
    const Tensor& /* gamma */,
    int64_t /* N */,
    int64_t /* C */,
    int64_t /* HxW */,
    int64_t /* group */,
    Tensor& /* dX */,
    Tensor& /* dgamma */,
    Tensor& /* dbeta */);

// Per-backend kernel registration points (see DispatchStub.h).
DECLARE_DISPATCH(forward_fn, GroupNormKernel);
DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel);

} // namespace native
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cdist_forward_native.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor & _cdist_forward_out(const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode, at::Tensor & out);
|
| 20 |
+
TORCH_API at::Tensor _cdist_forward(const at::Tensor & x1, const at::Tensor & x2, double p, c10::optional<int64_t> compute_mode);
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_clamp_min_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> _foreach_clamp_min(at::TensorList self, const at::Scalar & scalar);
|
| 21 |
+
TORCH_API void _foreach_clamp_min_(at::TensorList self, const at::Scalar & scalar);
|
| 22 |
+
TORCH_API ::std::vector<at::Tensor> _foreach_clamp_min(at::TensorList self, at::TensorList other);
|
| 23 |
+
TORCH_API void _foreach_clamp_min_(at::TensorList self, at::TensorList other);
|
| 24 |
+
TORCH_API ::std::vector<at::Tensor> _foreach_clamp_min(at::TensorList self, at::ArrayRef<at::Scalar> scalars);
|
| 25 |
+
TORCH_API void _foreach_clamp_min_(at::TensorList self, at::ArrayRef<at::Scalar> scalars);
|
| 26 |
+
|
| 27 |
+
} // namespace cpu
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_log_softmax_meta_dispatch.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace meta {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _log_softmax(const at::Tensor & self, int64_t dim, bool half_to_float);
|
| 21 |
+
TORCH_API at::Tensor & _log_softmax_out(at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float);
|
| 22 |
+
TORCH_API at::Tensor & _log_softmax_outf(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out);
|
| 23 |
+
|
| 24 |
+
} // namespace meta
|
| 25 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _make_per_tensor_quantized_tensor(const at::Tensor & self, double scale, int64_t zero_point);
|
| 21 |
+
|
| 22 |
+
} // namespace cpu
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor & _masked_softmax_backward_out(at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt);
|
| 21 |
+
TORCH_API at::Tensor & _masked_softmax_backward_outf(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim, at::Tensor & out);
|
| 22 |
+
|
| 23 |
+
} // namespace compositeexplicitautograd
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_backward_native.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor & _masked_softmax_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim, at::Tensor & out);
|
| 20 |
+
TORCH_API at::Tensor masked_softmax_backward_cpu(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt);
|
| 21 |
+
TORCH_API at::Tensor masked_softmax_backward_cuda(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt);
|
| 22 |
+
} // namespace native
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_masked_softmax_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _masked_softmax(const at::Tensor & self, const at::Tensor & mask, c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> mask_type=c10::nullopt);
|
| 21 |
+
|
| 22 |
+
} // namespace cpu
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nnpack_spatial_convolution_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _nnpack_spatial_convolution {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_nnpack_spatial_convolution")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _nnpack_spatial_convolution_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_nnpack_spatial_convolution")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _scaled_dot_product_efficient_attention_backward {
|
| 18 |
+
using schema = ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> (const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, const at::Tensor &, double, ::std::array<bool,4>, bool, c10::optional<double>);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_scaled_dot_product_efficient_attention_backward")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)")
|
| 24 |
+
static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> call(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array<bool,4> grad_input_mask, bool is_causal, c10::optional<double> scale);
|
| 25 |
+
static ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array<bool,4> grad_input_mask, bool is_causal, c10::optional<double> scale);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_serialization_subcmul_ops.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _test_serialization_subcmul {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_serialization_subcmul")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
}} // namespace at::_ops
|