Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h +1 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h +1 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h +325 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h +17 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h +13 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h +120 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Version.h +18 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh +149 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h +99 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h +105 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h +23 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh +121 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h +11 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h +13 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh +53 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h +58 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h +151 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh +124 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h +37 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h +11 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh +43 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh +116 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h +397 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h +611 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h +275 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h +34 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h +246 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h +307 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h +286 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h +98 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h +49 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h +321 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h +119 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h +97 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h +80 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h +298 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h +69 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h +0 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h +71 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h +27 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h +19 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/ArrayRef.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/DimVector.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Dimname.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Formatting.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_meta_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_add_relu_meta_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_addmm_activation_meta_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_amp_update_scale_meta_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_coalesced_meta_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_ctc_loss_meta_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_efficientzerotensor_meta_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_fused_sdp_choice_meta_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_index_put_impl_meta_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_linalg_det_meta_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_linalg_eigh_meta_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_linalg_slogdet_meta_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_linalg_solve_ex_meta_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_linalg_svd_meta_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_log_softmax_meta_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_log_softmax_backward_data_meta_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_mkldnn_transpose_meta_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_reshape_alias_meta_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_resize_output_meta_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_softmax_meta_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_softmax_backward_data_meta_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h>
|
| 54 |
+
#include <ATen/ops/acos_meta_dispatch.h>
|
| 55 |
+
#include <ATen/ops/acosh_meta_dispatch.h>
|
| 56 |
+
#include <ATen/ops/adaptive_max_pool2d_meta_dispatch.h>
|
| 57 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h>
|
| 58 |
+
#include <ATen/ops/adaptive_max_pool3d_meta_dispatch.h>
|
| 59 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h>
|
| 60 |
+
#include <ATen/ops/add_meta_dispatch.h>
|
| 61 |
+
#include <ATen/ops/addbmm_meta_dispatch.h>
|
| 62 |
+
#include <ATen/ops/addcdiv_meta_dispatch.h>
|
| 63 |
+
#include <ATen/ops/addcmul_meta_dispatch.h>
|
| 64 |
+
#include <ATen/ops/addmm_meta_dispatch.h>
|
| 65 |
+
#include <ATen/ops/addmv_meta_dispatch.h>
|
| 66 |
+
#include <ATen/ops/all_meta_dispatch.h>
|
| 67 |
+
#include <ATen/ops/amax_meta_dispatch.h>
|
| 68 |
+
#include <ATen/ops/amin_meta_dispatch.h>
|
| 69 |
+
#include <ATen/ops/aminmax_meta_dispatch.h>
|
| 70 |
+
#include <ATen/ops/any_meta_dispatch.h>
|
| 71 |
+
#include <ATen/ops/arange_meta_dispatch.h>
|
| 72 |
+
#include <ATen/ops/argmax_meta_dispatch.h>
|
| 73 |
+
#include <ATen/ops/argmin_meta_dispatch.h>
|
| 74 |
+
#include <ATen/ops/as_strided_meta_dispatch.h>
|
| 75 |
+
#include <ATen/ops/asin_meta_dispatch.h>
|
| 76 |
+
#include <ATen/ops/asinh_meta_dispatch.h>
|
| 77 |
+
#include <ATen/ops/atan_meta_dispatch.h>
|
| 78 |
+
#include <ATen/ops/atan2_meta_dispatch.h>
|
| 79 |
+
#include <ATen/ops/atanh_meta_dispatch.h>
|
| 80 |
+
#include <ATen/ops/avg_pool2d_meta_dispatch.h>
|
| 81 |
+
#include <ATen/ops/avg_pool2d_backward_meta_dispatch.h>
|
| 82 |
+
#include <ATen/ops/avg_pool3d_meta_dispatch.h>
|
| 83 |
+
#include <ATen/ops/avg_pool3d_backward_meta_dispatch.h>
|
| 84 |
+
#include <ATen/ops/baddbmm_meta_dispatch.h>
|
| 85 |
+
#include <ATen/ops/bernoulli_meta_dispatch.h>
|
| 86 |
+
#include <ATen/ops/bitwise_and_meta_dispatch.h>
|
| 87 |
+
#include <ATen/ops/bitwise_left_shift_meta_dispatch.h>
|
| 88 |
+
#include <ATen/ops/bitwise_not_meta_dispatch.h>
|
| 89 |
+
#include <ATen/ops/bitwise_or_meta_dispatch.h>
|
| 90 |
+
#include <ATen/ops/bitwise_right_shift_meta_dispatch.h>
|
| 91 |
+
#include <ATen/ops/bitwise_xor_meta_dispatch.h>
|
| 92 |
+
#include <ATen/ops/bmm_meta_dispatch.h>
|
| 93 |
+
#include <ATen/ops/cat_meta_dispatch.h>
|
| 94 |
+
#include <ATen/ops/cauchy_meta_dispatch.h>
|
| 95 |
+
#include <ATen/ops/ceil_meta_dispatch.h>
|
| 96 |
+
#include <ATen/ops/clamp_meta_dispatch.h>
|
| 97 |
+
#include <ATen/ops/clamp_max_meta_dispatch.h>
|
| 98 |
+
#include <ATen/ops/clamp_min_meta_dispatch.h>
|
| 99 |
+
#include <ATen/ops/copy_meta_dispatch.h>
|
| 100 |
+
#include <ATen/ops/copy_sparse_to_sparse_meta_dispatch.h>
|
| 101 |
+
#include <ATen/ops/copysign_meta_dispatch.h>
|
| 102 |
+
#include <ATen/ops/cos_meta_dispatch.h>
|
| 103 |
+
#include <ATen/ops/cosh_meta_dispatch.h>
|
| 104 |
+
#include <ATen/ops/cumprod_meta_dispatch.h>
|
| 105 |
+
#include <ATen/ops/cumsum_meta_dispatch.h>
|
| 106 |
+
#include <ATen/ops/digamma_meta_dispatch.h>
|
| 107 |
+
#include <ATen/ops/div_meta_dispatch.h>
|
| 108 |
+
#include <ATen/ops/elu_meta_dispatch.h>
|
| 109 |
+
#include <ATen/ops/elu_backward_meta_dispatch.h>
|
| 110 |
+
#include <ATen/ops/embedding_renorm_meta_dispatch.h>
|
| 111 |
+
#include <ATen/ops/empty_meta_dispatch.h>
|
| 112 |
+
#include <ATen/ops/empty_strided_meta_dispatch.h>
|
| 113 |
+
#include <ATen/ops/eq_meta_dispatch.h>
|
| 114 |
+
#include <ATen/ops/erf_meta_dispatch.h>
|
| 115 |
+
#include <ATen/ops/erfc_meta_dispatch.h>
|
| 116 |
+
#include <ATen/ops/erfinv_meta_dispatch.h>
|
| 117 |
+
#include <ATen/ops/exp_meta_dispatch.h>
|
| 118 |
+
#include <ATen/ops/exp2_meta_dispatch.h>
|
| 119 |
+
#include <ATen/ops/expm1_meta_dispatch.h>
|
| 120 |
+
#include <ATen/ops/exponential_meta_dispatch.h>
|
| 121 |
+
#include <ATen/ops/eye_meta_dispatch.h>
|
| 122 |
+
#include <ATen/ops/fill_meta_dispatch.h>
|
| 123 |
+
#include <ATen/ops/floor_meta_dispatch.h>
|
| 124 |
+
#include <ATen/ops/floor_divide_meta_dispatch.h>
|
| 125 |
+
#include <ATen/ops/fmax_meta_dispatch.h>
|
| 126 |
+
#include <ATen/ops/fmin_meta_dispatch.h>
|
| 127 |
+
#include <ATen/ops/fmod_meta_dispatch.h>
|
| 128 |
+
#include <ATen/ops/frac_meta_dispatch.h>
|
| 129 |
+
#include <ATen/ops/fractional_max_pool2d_meta_dispatch.h>
|
| 130 |
+
#include <ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h>
|
| 131 |
+
#include <ATen/ops/fractional_max_pool3d_meta_dispatch.h>
|
| 132 |
+
#include <ATen/ops/gather_meta_dispatch.h>
|
| 133 |
+
#include <ATen/ops/gcd_meta_dispatch.h>
|
| 134 |
+
#include <ATen/ops/ge_meta_dispatch.h>
|
| 135 |
+
#include <ATen/ops/gelu_meta_dispatch.h>
|
| 136 |
+
#include <ATen/ops/gelu_backward_meta_dispatch.h>
|
| 137 |
+
#include <ATen/ops/geometric_meta_dispatch.h>
|
| 138 |
+
#include <ATen/ops/glu_meta_dispatch.h>
|
| 139 |
+
#include <ATen/ops/gt_meta_dispatch.h>
|
| 140 |
+
#include <ATen/ops/hardshrink_meta_dispatch.h>
|
| 141 |
+
#include <ATen/ops/hardshrink_backward_meta_dispatch.h>
|
| 142 |
+
#include <ATen/ops/hardsigmoid_meta_dispatch.h>
|
| 143 |
+
#include <ATen/ops/hardsigmoid_backward_meta_dispatch.h>
|
| 144 |
+
#include <ATen/ops/hardswish_meta_dispatch.h>
|
| 145 |
+
#include <ATen/ops/hardtanh_meta_dispatch.h>
|
| 146 |
+
#include <ATen/ops/heaviside_meta_dispatch.h>
|
| 147 |
+
#include <ATen/ops/hypot_meta_dispatch.h>
|
| 148 |
+
#include <ATen/ops/i0_meta_dispatch.h>
|
| 149 |
+
#include <ATen/ops/igamma_meta_dispatch.h>
|
| 150 |
+
#include <ATen/ops/igammac_meta_dispatch.h>
|
| 151 |
+
#include <ATen/ops/index_meta_dispatch.h>
|
| 152 |
+
#include <ATen/ops/index_add_meta_dispatch.h>
|
| 153 |
+
#include <ATen/ops/index_copy_meta_dispatch.h>
|
| 154 |
+
#include <ATen/ops/index_fill_meta_dispatch.h>
|
| 155 |
+
#include <ATen/ops/index_reduce_meta_dispatch.h>
|
| 156 |
+
#include <ATen/ops/isin_meta_dispatch.h>
|
| 157 |
+
#include <ATen/ops/isneginf_meta_dispatch.h>
|
| 158 |
+
#include <ATen/ops/isposinf_meta_dispatch.h>
|
| 159 |
+
#include <ATen/ops/lcm_meta_dispatch.h>
|
| 160 |
+
#include <ATen/ops/le_meta_dispatch.h>
|
| 161 |
+
#include <ATen/ops/leaky_relu_meta_dispatch.h>
|
| 162 |
+
#include <ATen/ops/leaky_relu_backward_meta_dispatch.h>
|
| 163 |
+
#include <ATen/ops/lerp_meta_dispatch.h>
|
| 164 |
+
#include <ATen/ops/lgamma_meta_dispatch.h>
|
| 165 |
+
#include <ATen/ops/linalg_cholesky_ex_meta_dispatch.h>
|
| 166 |
+
#include <ATen/ops/linalg_cross_meta_dispatch.h>
|
| 167 |
+
#include <ATen/ops/linalg_inv_ex_meta_dispatch.h>
|
| 168 |
+
#include <ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h>
|
| 169 |
+
#include <ATen/ops/linalg_ldl_solve_meta_dispatch.h>
|
| 170 |
+
#include <ATen/ops/linalg_lu_meta_dispatch.h>
|
| 171 |
+
#include <ATen/ops/linalg_lu_factor_ex_meta_dispatch.h>
|
| 172 |
+
#include <ATen/ops/linalg_lu_solve_meta_dispatch.h>
|
| 173 |
+
#include <ATen/ops/linalg_qr_meta_dispatch.h>
|
| 174 |
+
#include <ATen/ops/linalg_vector_norm_meta_dispatch.h>
|
| 175 |
+
#include <ATen/ops/linspace_meta_dispatch.h>
|
| 176 |
+
#include <ATen/ops/log_meta_dispatch.h>
|
| 177 |
+
#include <ATen/ops/log10_meta_dispatch.h>
|
| 178 |
+
#include <ATen/ops/log1p_meta_dispatch.h>
|
| 179 |
+
#include <ATen/ops/log2_meta_dispatch.h>
|
| 180 |
+
#include <ATen/ops/log_normal_meta_dispatch.h>
|
| 181 |
+
#include <ATen/ops/logaddexp_meta_dispatch.h>
|
| 182 |
+
#include <ATen/ops/logaddexp2_meta_dispatch.h>
|
| 183 |
+
#include <ATen/ops/logit_meta_dispatch.h>
|
| 184 |
+
#include <ATen/ops/logit_backward_meta_dispatch.h>
|
| 185 |
+
#include <ATen/ops/logspace_meta_dispatch.h>
|
| 186 |
+
#include <ATen/ops/lshift_meta_dispatch.h>
|
| 187 |
+
#include <ATen/ops/lt_meta_dispatch.h>
|
| 188 |
+
#include <ATen/ops/lu_unpack_meta_dispatch.h>
|
| 189 |
+
#include <ATen/ops/masked_fill_meta_dispatch.h>
|
| 190 |
+
#include <ATen/ops/masked_scatter_meta_dispatch.h>
|
| 191 |
+
#include <ATen/ops/max_meta_dispatch.h>
|
| 192 |
+
#include <ATen/ops/max_pool2d_with_indices_meta_dispatch.h>
|
| 193 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h>
|
| 194 |
+
#include <ATen/ops/maximum_meta_dispatch.h>
|
| 195 |
+
#include <ATen/ops/mean_meta_dispatch.h>
|
| 196 |
+
#include <ATen/ops/min_meta_dispatch.h>
|
| 197 |
+
#include <ATen/ops/minimum_meta_dispatch.h>
|
| 198 |
+
#include <ATen/ops/mish_meta_dispatch.h>
|
| 199 |
+
#include <ATen/ops/mm_meta_dispatch.h>
|
| 200 |
+
#include <ATen/ops/mse_loss_meta_dispatch.h>
|
| 201 |
+
#include <ATen/ops/mul_meta_dispatch.h>
|
| 202 |
+
#include <ATen/ops/ne_meta_dispatch.h>
|
| 203 |
+
#include <ATen/ops/neg_meta_dispatch.h>
|
| 204 |
+
#include <ATen/ops/nextafter_meta_dispatch.h>
|
| 205 |
+
#include <ATen/ops/nll_loss_backward_meta_dispatch.h>
|
| 206 |
+
#include <ATen/ops/nll_loss_forward_meta_dispatch.h>
|
| 207 |
+
#include <ATen/ops/norm_meta_dispatch.h>
|
| 208 |
+
#include <ATen/ops/normal_meta_dispatch.h>
|
| 209 |
+
#include <ATen/ops/polygamma_meta_dispatch.h>
|
| 210 |
+
#include <ATen/ops/pow_meta_dispatch.h>
|
| 211 |
+
#include <ATen/ops/prod_meta_dispatch.h>
|
| 212 |
+
#include <ATen/ops/put_meta_dispatch.h>
|
| 213 |
+
#include <ATen/ops/random_meta_dispatch.h>
|
| 214 |
+
#include <ATen/ops/range_meta_dispatch.h>
|
| 215 |
+
#include <ATen/ops/reciprocal_meta_dispatch.h>
|
| 216 |
+
#include <ATen/ops/reflection_pad1d_meta_dispatch.h>
|
| 217 |
+
#include <ATen/ops/reflection_pad1d_backward_meta_dispatch.h>
|
| 218 |
+
#include <ATen/ops/reflection_pad3d_meta_dispatch.h>
|
| 219 |
+
#include <ATen/ops/reflection_pad3d_backward_meta_dispatch.h>
|
| 220 |
+
#include <ATen/ops/relu_meta_dispatch.h>
|
| 221 |
+
#include <ATen/ops/remainder_meta_dispatch.h>
|
| 222 |
+
#include <ATen/ops/renorm_meta_dispatch.h>
|
| 223 |
+
#include <ATen/ops/replication_pad1d_meta_dispatch.h>
|
| 224 |
+
#include <ATen/ops/replication_pad1d_backward_meta_dispatch.h>
|
| 225 |
+
#include <ATen/ops/replication_pad2d_meta_dispatch.h>
|
| 226 |
+
#include <ATen/ops/replication_pad3d_meta_dispatch.h>
|
| 227 |
+
#include <ATen/ops/resize_meta_dispatch.h>
|
| 228 |
+
#include <ATen/ops/resize_as_sparse_meta_dispatch.h>
|
| 229 |
+
#include <ATen/ops/round_meta_dispatch.h>
|
| 230 |
+
#include <ATen/ops/rrelu_with_noise_meta_dispatch.h>
|
| 231 |
+
#include <ATen/ops/rshift_meta_dispatch.h>
|
| 232 |
+
#include <ATen/ops/rsqrt_meta_dispatch.h>
|
| 233 |
+
#include <ATen/ops/scatter_meta_dispatch.h>
|
| 234 |
+
#include <ATen/ops/scatter_add_meta_dispatch.h>
|
| 235 |
+
#include <ATen/ops/scatter_reduce_meta_dispatch.h>
|
| 236 |
+
#include <ATen/ops/set_meta_dispatch.h>
|
| 237 |
+
#include <ATen/ops/sgn_meta_dispatch.h>
|
| 238 |
+
#include <ATen/ops/sigmoid_meta_dispatch.h>
|
| 239 |
+
#include <ATen/ops/sigmoid_backward_meta_dispatch.h>
|
| 240 |
+
#include <ATen/ops/sign_meta_dispatch.h>
|
| 241 |
+
#include <ATen/ops/signbit_meta_dispatch.h>
|
| 242 |
+
#include <ATen/ops/silu_meta_dispatch.h>
|
| 243 |
+
#include <ATen/ops/silu_backward_meta_dispatch.h>
|
| 244 |
+
#include <ATen/ops/sin_meta_dispatch.h>
|
| 245 |
+
#include <ATen/ops/sinc_meta_dispatch.h>
|
| 246 |
+
#include <ATen/ops/sinh_meta_dispatch.h>
|
| 247 |
+
#include <ATen/ops/slow_conv_transpose2d_meta_dispatch.h>
|
| 248 |
+
#include <ATen/ops/smooth_l1_loss_meta_dispatch.h>
|
| 249 |
+
#include <ATen/ops/softplus_meta_dispatch.h>
|
| 250 |
+
#include <ATen/ops/softplus_backward_meta_dispatch.h>
|
| 251 |
+
#include <ATen/ops/softshrink_meta_dispatch.h>
|
| 252 |
+
#include <ATen/ops/softshrink_backward_meta_dispatch.h>
|
| 253 |
+
#include <ATen/ops/sort_meta_dispatch.h>
|
| 254 |
+
#include <ATen/ops/sparse_resize_meta_dispatch.h>
|
| 255 |
+
#include <ATen/ops/sparse_resize_and_clear_meta_dispatch.h>
|
| 256 |
+
#include <ATen/ops/special_airy_ai_meta_dispatch.h>
|
| 257 |
+
#include <ATen/ops/special_bessel_j0_meta_dispatch.h>
|
| 258 |
+
#include <ATen/ops/special_bessel_j1_meta_dispatch.h>
|
| 259 |
+
#include <ATen/ops/special_bessel_y0_meta_dispatch.h>
|
| 260 |
+
#include <ATen/ops/special_bessel_y1_meta_dispatch.h>
|
| 261 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h>
|
| 262 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h>
|
| 263 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h>
|
| 264 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h>
|
| 265 |
+
#include <ATen/ops/special_entr_meta_dispatch.h>
|
| 266 |
+
#include <ATen/ops/special_erfcx_meta_dispatch.h>
|
| 267 |
+
#include <ATen/ops/special_hermite_polynomial_h_meta_dispatch.h>
|
| 268 |
+
#include <ATen/ops/special_hermite_polynomial_he_meta_dispatch.h>
|
| 269 |
+
#include <ATen/ops/special_i0e_meta_dispatch.h>
|
| 270 |
+
#include <ATen/ops/special_i1_meta_dispatch.h>
|
| 271 |
+
#include <ATen/ops/special_i1e_meta_dispatch.h>
|
| 272 |
+
#include <ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h>
|
| 273 |
+
#include <ATen/ops/special_legendre_polynomial_p_meta_dispatch.h>
|
| 274 |
+
#include <ATen/ops/special_log_ndtr_meta_dispatch.h>
|
| 275 |
+
#include <ATen/ops/special_modified_bessel_i0_meta_dispatch.h>
|
| 276 |
+
#include <ATen/ops/special_modified_bessel_i1_meta_dispatch.h>
|
| 277 |
+
#include <ATen/ops/special_modified_bessel_k0_meta_dispatch.h>
|
| 278 |
+
#include <ATen/ops/special_modified_bessel_k1_meta_dispatch.h>
|
| 279 |
+
#include <ATen/ops/special_ndtri_meta_dispatch.h>
|
| 280 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h>
|
| 281 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h>
|
| 282 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h>
|
| 283 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h>
|
| 284 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h>
|
| 285 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h>
|
| 286 |
+
#include <ATen/ops/special_spherical_bessel_j0_meta_dispatch.h>
|
| 287 |
+
#include <ATen/ops/special_xlog1py_meta_dispatch.h>
|
| 288 |
+
#include <ATen/ops/special_zeta_meta_dispatch.h>
|
| 289 |
+
#include <ATen/ops/sqrt_meta_dispatch.h>
|
| 290 |
+
#include <ATen/ops/sub_meta_dispatch.h>
|
| 291 |
+
#include <ATen/ops/sum_meta_dispatch.h>
|
| 292 |
+
#include <ATen/ops/tan_meta_dispatch.h>
|
| 293 |
+
#include <ATen/ops/tanh_meta_dispatch.h>
|
| 294 |
+
#include <ATen/ops/tanh_backward_meta_dispatch.h>
|
| 295 |
+
#include <ATen/ops/threshold_meta_dispatch.h>
|
| 296 |
+
#include <ATen/ops/threshold_backward_meta_dispatch.h>
|
| 297 |
+
#include <ATen/ops/topk_meta_dispatch.h>
|
| 298 |
+
#include <ATen/ops/triangular_solve_meta_dispatch.h>
|
| 299 |
+
#include <ATen/ops/tril_meta_dispatch.h>
|
| 300 |
+
#include <ATen/ops/triu_meta_dispatch.h>
|
| 301 |
+
#include <ATen/ops/trunc_meta_dispatch.h>
|
| 302 |
+
#include <ATen/ops/unfold_meta_dispatch.h>
|
| 303 |
+
#include <ATen/ops/uniform_meta_dispatch.h>
|
| 304 |
+
#include <ATen/ops/upsample_bicubic2d_meta_dispatch.h>
|
| 305 |
+
#include <ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h>
|
| 306 |
+
#include <ATen/ops/upsample_bilinear2d_meta_dispatch.h>
|
| 307 |
+
#include <ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h>
|
| 308 |
+
#include <ATen/ops/upsample_linear1d_meta_dispatch.h>
|
| 309 |
+
#include <ATen/ops/upsample_linear1d_backward_meta_dispatch.h>
|
| 310 |
+
#include <ATen/ops/upsample_nearest1d_meta_dispatch.h>
|
| 311 |
+
#include <ATen/ops/upsample_nearest1d_backward_meta_dispatch.h>
|
| 312 |
+
#include <ATen/ops/upsample_nearest2d_meta_dispatch.h>
|
| 313 |
+
#include <ATen/ops/upsample_nearest2d_backward_meta_dispatch.h>
|
| 314 |
+
#include <ATen/ops/upsample_nearest3d_meta_dispatch.h>
|
| 315 |
+
#include <ATen/ops/upsample_nearest3d_backward_meta_dispatch.h>
|
| 316 |
+
#include <ATen/ops/upsample_trilinear3d_meta_dispatch.h>
|
| 317 |
+
#include <ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h>
|
| 318 |
+
#include <ATen/ops/view_meta_dispatch.h>
|
| 319 |
+
#include <ATen/ops/view_as_complex_meta_dispatch.h>
|
| 320 |
+
#include <ATen/ops/view_as_real_meta_dispatch.h>
|
| 321 |
+
#include <ATen/ops/xlogy_meta_dispatch.h>
|
| 322 |
+
#include <ATen/ops/zero_meta_dispatch.h>
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
#include <c10/core/thread_pool.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
class TORCH_API PTThreadPool : public c10::ThreadPool {
|
| 9 |
+
public:
|
| 10 |
+
explicit PTThreadPool(int pool_size, int numa_node_id = -1)
|
| 11 |
+
: c10::ThreadPool(pool_size, numa_node_id, []() {
|
| 12 |
+
c10::setThreadName("PTThreadPool");
|
| 13 |
+
at::init_num_threads();
|
| 14 |
+
}) {}
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
// A simple thread local enumeration, used to link forward and backward pass
|
| 7 |
+
// ops and is used by autograd and observers framework
|
| 8 |
+
namespace at::sequence_number {
|
| 9 |
+
|
| 10 |
+
TORCH_API uint64_t peek();
|
| 11 |
+
TORCH_API uint64_t get_and_increment();
|
| 12 |
+
|
| 13 |
+
} // namespace at::sequence_number
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/SafePyObject.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <unordered_map>
|
| 6 |
+
|
| 7 |
+
namespace at::impl {
|
| 8 |
+
|
| 9 |
+
struct TORCH_API ThreadLocalPythonObjects {
|
| 10 |
+
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
|
| 11 |
+
static const std::shared_ptr<SafePyObject>& get(const std::string& key);
|
| 12 |
+
static bool contains(const std::string& key);
|
| 13 |
+
|
| 14 |
+
static const ThreadLocalPythonObjects& get_state();
|
| 15 |
+
static void set_state(ThreadLocalPythonObjects state);
|
| 16 |
+
|
| 17 |
+
private:
|
| 18 |
+
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
} // namespace at::impl
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/InferenceMode.h>
|
| 4 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/ThreadLocalDebugInfo.h>
|
| 7 |
+
|
| 8 |
+
#include <ATen/FuncTorchTLS.h>
|
| 9 |
+
#include <ATen/PythonTorchFunctionTLS.h>
|
| 10 |
+
#include <ATen/SavedTensorHooks.h>
|
| 11 |
+
#include <ATen/ThreadLocalPythonObjects.h>
|
| 12 |
+
#include <ATen/record_function.h>
|
| 13 |
+
#include <c10/core/impl/PythonDispatcherTLS.h>
|
| 14 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
// Thread local state contains values that are preserved across
|
| 19 |
+
// thread boundaries (e.g. at::launch/JIT fork, autograd).
|
| 20 |
+
// Note at::parallel_for doesn't preserve TLS across thread boundaries.
|
| 21 |
+
class TORCH_API ThreadLocalState {
|
| 22 |
+
public:
|
| 23 |
+
// Saves the thread local variables' values and
|
| 24 |
+
// returns them as a ThreadLocalState
|
| 25 |
+
ThreadLocalState();
|
| 26 |
+
|
| 27 |
+
// set_grad_mode - force the value of the grad mode TLS in
|
| 28 |
+
// the current state object. This is used for example in the
|
| 29 |
+
// autograd engine.
|
| 30 |
+
void set_grad_mode(bool enabled);
|
| 31 |
+
|
| 32 |
+
// set_multithreading_enabled - force the value of the multithreadinmaximum
|
| 33 |
+
// threads TLS in
|
| 34 |
+
// the current state object. This is used for example in the
|
| 35 |
+
// autograd engine.
|
| 36 |
+
void set_multithreading_enabled(bool enabled);
|
| 37 |
+
|
| 38 |
+
// Sets thread local variables in the current thread,
|
| 39 |
+
// according to the thread boundary specified
|
| 40 |
+
static void setThreadLocalState(const ThreadLocalState& state);
|
| 41 |
+
|
| 42 |
+
private:
|
| 43 |
+
c10::impl::LocalDispatchKeySet dispatch_key_;
|
| 44 |
+
|
| 45 |
+
// ThreadLocalDebugInfo does not change after being created
|
| 46 |
+
// with DebugInfoGuard
|
| 47 |
+
std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
|
| 48 |
+
|
| 49 |
+
// RecordFunction TLS
|
| 50 |
+
RecordFunctionTLS rf_tls_;
|
| 51 |
+
|
| 52 |
+
// TLS for out-of-tree functorch
|
| 53 |
+
// See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
|
| 54 |
+
// pointer (spoiler alert: it's due to the indirection)
|
| 55 |
+
// This needs to be a shared_ptr instead of a unique_ptr because
|
| 56 |
+
// ThreadLocalState is copy-able and does indeed get copied. Maybe we can
|
| 57 |
+
// consider adding an explicit copy constructor for ThreadLocalState in the
|
| 58 |
+
// future but I didn't want to add one just for this.
|
| 59 |
+
std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
|
| 60 |
+
|
| 61 |
+
// TLS for AutogradModes
|
| 62 |
+
AutogradState autograd_tls_;
|
| 63 |
+
|
| 64 |
+
// TLS for enable_torch_dispatch_mode
|
| 65 |
+
c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
|
| 66 |
+
|
| 67 |
+
// TLS for enable_python_dispatcher
|
| 68 |
+
c10::impl::PyInterpreter* python_dispatcher_state_;
|
| 69 |
+
|
| 70 |
+
// TLS for __torch_function__ (mode and disable_torch_function)
|
| 71 |
+
at::impl::PythonTorchFunctionTLS python_torch_function_state_;
|
| 72 |
+
|
| 73 |
+
// TLS for saved tensors default hooks
|
| 74 |
+
at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
|
| 75 |
+
|
| 76 |
+
bool functionalization_reapply_views_state_;
|
| 77 |
+
|
| 78 |
+
// TLS for arbitrary python objects that is registered via hooks
|
| 79 |
+
at::impl::ThreadLocalPythonObjects saved_objects_;
|
| 80 |
+
|
| 81 |
+
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
|
| 82 |
+
!defined(BUILD_LITE_INTERPRETER)
|
| 83 |
+
// TLS for autocast dtypes
|
| 84 |
+
std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
| 85 |
+
autocast_dtypes_;
|
| 86 |
+
#endif
|
| 87 |
+
|
| 88 |
+
friend class ThreadLocalStateGuard;
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
// Guard to set and reset the thread local state
|
| 92 |
+
class TORCH_API ThreadLocalStateGuard {
|
| 93 |
+
public:
|
| 94 |
+
explicit ThreadLocalStateGuard(const ThreadLocalState& state)
|
| 95 |
+
: prev_state_(ThreadLocalState()) {
|
| 96 |
+
// set the given state across the thread boundary
|
| 97 |
+
ThreadLocalState::setThreadLocalState(state);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
~ThreadLocalStateGuard() {
|
| 101 |
+
// restore previously set variables
|
| 102 |
+
ThreadLocalState::setThreadLocalState(prev_state_);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
private:
|
| 106 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 107 |
+
const ThreadLocalState prev_state_;
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
template <typename T>
|
| 111 |
+
auto wrapPropagateTLSState(T callback) {
|
| 112 |
+
return [tls_state = ThreadLocalState(),
|
| 113 |
+
callback = std::move(callback)](auto&&... args) {
|
| 114 |
+
ThreadLocalStateGuard g(tls_state);
|
| 115 |
+
// Propagate value returned by callback().
|
| 116 |
+
return callback(std::forward<decltype(args)>(args)...);
|
| 117 |
+
};
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/Context.h>
|
| 2 |
+
|
| 3 |
+
namespace at {
|
| 4 |
+
|
| 5 |
+
/// Returns a detailed string describing the configuration PyTorch.
|
| 6 |
+
TORCH_API std::string show_config();
|
| 7 |
+
|
| 8 |
+
TORCH_API std::string get_mkl_version();
|
| 9 |
+
|
| 10 |
+
TORCH_API std::string get_mkldnn_version();
|
| 11 |
+
|
| 12 |
+
TORCH_API std::string get_openmp_version();
|
| 13 |
+
|
| 14 |
+
TORCH_API std::string get_cxx_flags();
|
| 15 |
+
|
| 16 |
+
TORCH_API std::string get_cpu_capability();
|
| 17 |
+
|
| 18 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <cstdint>
|
| 3 |
+
|
| 4 |
+
// Collection of direct PTX functions
|
| 5 |
+
|
| 6 |
+
namespace at::cuda {
|
| 7 |
+
|
| 8 |
+
template <typename T>
|
| 9 |
+
struct Bitfield {};
|
| 10 |
+
|
| 11 |
+
template <>
|
| 12 |
+
struct Bitfield<unsigned int> {
|
| 13 |
+
static __device__ __host__ __forceinline__
|
| 14 |
+
unsigned int getBitfield(unsigned int val, int pos, int len) {
|
| 15 |
+
#if !defined(__CUDA_ARCH__)
|
| 16 |
+
pos &= 0xff;
|
| 17 |
+
len &= 0xff;
|
| 18 |
+
|
| 19 |
+
unsigned int m = (1u << len) - 1u;
|
| 20 |
+
return (val >> pos) & m;
|
| 21 |
+
#else
|
| 22 |
+
unsigned int ret;
|
| 23 |
+
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
|
| 24 |
+
return ret;
|
| 25 |
+
#endif
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
static __device__ __host__ __forceinline__
|
| 29 |
+
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
|
| 30 |
+
#if !defined(__CUDA_ARCH__)
|
| 31 |
+
pos &= 0xff;
|
| 32 |
+
len &= 0xff;
|
| 33 |
+
|
| 34 |
+
unsigned int m = (1u << len) - 1u;
|
| 35 |
+
toInsert &= m;
|
| 36 |
+
toInsert <<= pos;
|
| 37 |
+
m <<= pos;
|
| 38 |
+
|
| 39 |
+
return (val & ~m) | toInsert;
|
| 40 |
+
#else
|
| 41 |
+
unsigned int ret;
|
| 42 |
+
asm("bfi.b32 %0, %1, %2, %3, %4;" :
|
| 43 |
+
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
|
| 44 |
+
return ret;
|
| 45 |
+
#endif
|
| 46 |
+
}
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
template <>
|
| 50 |
+
struct Bitfield<uint64_t> {
|
| 51 |
+
static __device__ __host__ __forceinline__
|
| 52 |
+
uint64_t getBitfield(uint64_t val, int pos, int len) {
|
| 53 |
+
#if !defined(__CUDA_ARCH__)
|
| 54 |
+
pos &= 0xff;
|
| 55 |
+
len &= 0xff;
|
| 56 |
+
|
| 57 |
+
uint64_t m = (1u << len) - 1u;
|
| 58 |
+
return (val >> pos) & m;
|
| 59 |
+
#else
|
| 60 |
+
uint64_t ret;
|
| 61 |
+
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
|
| 62 |
+
return ret;
|
| 63 |
+
#endif
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
static __device__ __host__ __forceinline__
|
| 67 |
+
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
|
| 68 |
+
#if !defined(__CUDA_ARCH__)
|
| 69 |
+
pos &= 0xff;
|
| 70 |
+
len &= 0xff;
|
| 71 |
+
|
| 72 |
+
uint64_t m = (1u << len) - 1u;
|
| 73 |
+
toInsert &= m;
|
| 74 |
+
toInsert <<= pos;
|
| 75 |
+
m <<= pos;
|
| 76 |
+
|
| 77 |
+
return (val & ~m) | toInsert;
|
| 78 |
+
#else
|
| 79 |
+
uint64_t ret;
|
| 80 |
+
asm("bfi.b64 %0, %1, %2, %3, %4;" :
|
| 81 |
+
"=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
|
| 82 |
+
return ret;
|
| 83 |
+
#endif
|
| 84 |
+
}
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
__device__ __forceinline__ int getLaneId() {
|
| 88 |
+
#if defined(USE_ROCM)
|
| 89 |
+
return __lane_id();
|
| 90 |
+
#else
|
| 91 |
+
int laneId;
|
| 92 |
+
asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
|
| 93 |
+
return laneId;
|
| 94 |
+
#endif
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
#if defined(USE_ROCM)
|
| 98 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
|
| 99 |
+
const std::uint64_t m = (1ull << getLaneId()) - 1ull;
|
| 100 |
+
return m;
|
| 101 |
+
}
|
| 102 |
+
#else
|
| 103 |
+
__device__ __forceinline__ unsigned getLaneMaskLt() {
|
| 104 |
+
unsigned mask;
|
| 105 |
+
asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
|
| 106 |
+
return mask;
|
| 107 |
+
}
|
| 108 |
+
#endif
|
| 109 |
+
|
| 110 |
+
#if defined (USE_ROCM)
|
| 111 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
|
| 112 |
+
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
|
| 113 |
+
return m;
|
| 114 |
+
}
|
| 115 |
+
#else
|
| 116 |
+
__device__ __forceinline__ unsigned getLaneMaskLe() {
|
| 117 |
+
unsigned mask;
|
| 118 |
+
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
|
| 119 |
+
return mask;
|
| 120 |
+
}
|
| 121 |
+
#endif
|
| 122 |
+
|
| 123 |
+
#if defined(USE_ROCM)
|
| 124 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
|
| 125 |
+
const std::uint64_t m = getLaneMaskLe();
|
| 126 |
+
return m ? ~m : m;
|
| 127 |
+
}
|
| 128 |
+
#else
|
| 129 |
+
__device__ __forceinline__ unsigned getLaneMaskGt() {
|
| 130 |
+
unsigned mask;
|
| 131 |
+
asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
|
| 132 |
+
return mask;
|
| 133 |
+
}
|
| 134 |
+
#endif
|
| 135 |
+
|
| 136 |
+
#if defined(USE_ROCM)
|
| 137 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
|
| 138 |
+
const std::uint64_t m = getLaneMaskLt();
|
| 139 |
+
return ~m;
|
| 140 |
+
}
|
| 141 |
+
#else
|
| 142 |
+
__device__ __forceinline__ unsigned getLaneMaskGe() {
|
| 143 |
+
unsigned mask;
|
| 144 |
+
asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
|
| 145 |
+
return mask;
|
| 146 |
+
}
|
| 147 |
+
#endif
|
| 148 |
+
|
| 149 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// Light-weight version of CUDAContext.h with fewer transitive includes
|
| 3 |
+
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#include <cuda_runtime_api.h>
|
| 7 |
+
#include <cusparse.h>
|
| 8 |
+
#include <cublas_v2.h>
|
| 9 |
+
|
| 10 |
+
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
| 11 |
+
// added bf16 support
|
| 12 |
+
#include <cublasLt.h>
|
| 13 |
+
|
| 14 |
+
#ifdef CUDART_VERSION
|
| 15 |
+
#include <cusolverDn.h>
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#if defined(USE_CUDSS)
|
| 19 |
+
#include <cudss.h>
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
#if defined(USE_ROCM)
|
| 23 |
+
#include <hipsolver/hipsolver.h>
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
#include <c10/core/Allocator.h>
|
| 27 |
+
#include <c10/cuda/CUDAFunctions.h>
|
| 28 |
+
|
| 29 |
+
namespace c10 {
|
| 30 |
+
struct Allocator;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
namespace at::cuda {
|
| 34 |
+
|
| 35 |
+
/*
|
| 36 |
+
A common CUDA interface for ATen.
|
| 37 |
+
|
| 38 |
+
This interface is distinct from CUDAHooks, which defines an interface that links
|
| 39 |
+
to both CPU-only and CUDA builds. That interface is intended for runtime
|
| 40 |
+
dispatch and should be used from files that are included in both CPU-only and
|
| 41 |
+
CUDA builds.
|
| 42 |
+
|
| 43 |
+
CUDAContext, on the other hand, should be preferred by files only included in
|
| 44 |
+
CUDA builds. It is intended to expose CUDA functionality in a consistent
|
| 45 |
+
manner.
|
| 46 |
+
|
| 47 |
+
This means there is some overlap between the CUDAContext and CUDAHooks, but
|
| 48 |
+
the choice of which to use is simple: use CUDAContext when in a CUDA-only file,
|
| 49 |
+
use CUDAHooks otherwise.
|
| 50 |
+
|
| 51 |
+
Note that CUDAContext simply defines an interface with no associated class.
|
| 52 |
+
It is expected that the modules whose functions compose this interface will
|
| 53 |
+
manage their own state. There is only a single CUDA context/state.
|
| 54 |
+
*/
|
| 55 |
+
|
| 56 |
+
/**
|
| 57 |
+
* DEPRECATED: use device_count() instead
|
| 58 |
+
*/
|
| 59 |
+
inline int64_t getNumGPUs() {
|
| 60 |
+
return c10::cuda::device_count();
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
/**
|
| 64 |
+
* CUDA is available if we compiled with CUDA, and there are one or more
|
| 65 |
+
* devices. If we compiled with CUDA but there is a driver problem, etc.,
|
| 66 |
+
* this function will report CUDA is not available (rather than raise an error.)
|
| 67 |
+
*/
|
| 68 |
+
inline bool is_available() {
|
| 69 |
+
return c10::cuda::device_count() > 0;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
|
| 73 |
+
|
| 74 |
+
TORCH_CUDA_CPP_API int warp_size();
|
| 75 |
+
|
| 76 |
+
TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device);
|
| 77 |
+
|
| 78 |
+
TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
|
| 79 |
+
c10::DeviceIndex device,
|
| 80 |
+
c10::DeviceIndex peer_device);
|
| 81 |
+
|
| 82 |
+
TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
| 83 |
+
|
| 84 |
+
/* Handles */
|
| 85 |
+
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
| 86 |
+
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
| 87 |
+
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
| 88 |
+
|
| 89 |
+
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
| 90 |
+
|
| 91 |
+
#if defined(CUDART_VERSION) || defined(USE_ROCM)
|
| 92 |
+
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
|
| 93 |
+
#endif
|
| 94 |
+
|
| 95 |
+
#if defined(USE_CUDSS)
|
| 96 |
+
TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle();
|
| 97 |
+
#endif
|
| 98 |
+
|
| 99 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
|
| 5 |
+
#include <cuda.h>
|
| 6 |
+
#include <library_types.h>
|
| 7 |
+
|
| 8 |
+
namespace at::cuda {
|
| 9 |
+
|
| 10 |
+
template <typename scalar_t>
|
| 11 |
+
cudaDataType getCudaDataType() {
|
| 12 |
+
static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType.");
|
| 13 |
+
return {};
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
template<> inline cudaDataType getCudaDataType<at::Half>() {
|
| 17 |
+
return CUDA_R_16F;
|
| 18 |
+
}
|
| 19 |
+
template<> inline cudaDataType getCudaDataType<float>() {
|
| 20 |
+
return CUDA_R_32F;
|
| 21 |
+
}
|
| 22 |
+
template<> inline cudaDataType getCudaDataType<double>() {
|
| 23 |
+
return CUDA_R_64F;
|
| 24 |
+
}
|
| 25 |
+
template<> inline cudaDataType getCudaDataType<c10::complex<c10::Half>>() {
|
| 26 |
+
return CUDA_C_16F;
|
| 27 |
+
}
|
| 28 |
+
template<> inline cudaDataType getCudaDataType<c10::complex<float>>() {
|
| 29 |
+
return CUDA_C_32F;
|
| 30 |
+
}
|
| 31 |
+
template<> inline cudaDataType getCudaDataType<c10::complex<double>>() {
|
| 32 |
+
return CUDA_C_64F;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
template<> inline cudaDataType getCudaDataType<uint8_t>() {
|
| 36 |
+
return CUDA_R_8U;
|
| 37 |
+
}
|
| 38 |
+
template<> inline cudaDataType getCudaDataType<int8_t>() {
|
| 39 |
+
return CUDA_R_8I;
|
| 40 |
+
}
|
| 41 |
+
template<> inline cudaDataType getCudaDataType<int>() {
|
| 42 |
+
return CUDA_R_32I;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template<> inline cudaDataType getCudaDataType<int16_t>() {
|
| 46 |
+
return CUDA_R_16I;
|
| 47 |
+
}
|
| 48 |
+
template<> inline cudaDataType getCudaDataType<int64_t>() {
|
| 49 |
+
return CUDA_R_64I;
|
| 50 |
+
}
|
| 51 |
+
template<> inline cudaDataType getCudaDataType<at::BFloat16>() {
|
| 52 |
+
return CUDA_R_16BF;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) {
|
| 56 |
+
switch (scalar_type) {
|
| 57 |
+
case c10::ScalarType::Byte:
|
| 58 |
+
return CUDA_R_8U;
|
| 59 |
+
case c10::ScalarType::Char:
|
| 60 |
+
return CUDA_R_8I;
|
| 61 |
+
case c10::ScalarType::Int:
|
| 62 |
+
return CUDA_R_32I;
|
| 63 |
+
case c10::ScalarType::Half:
|
| 64 |
+
return CUDA_R_16F;
|
| 65 |
+
case c10::ScalarType::Float:
|
| 66 |
+
return CUDA_R_32F;
|
| 67 |
+
case c10::ScalarType::Double:
|
| 68 |
+
return CUDA_R_64F;
|
| 69 |
+
case c10::ScalarType::ComplexHalf:
|
| 70 |
+
return CUDA_C_16F;
|
| 71 |
+
case c10::ScalarType::ComplexFloat:
|
| 72 |
+
return CUDA_C_32F;
|
| 73 |
+
case c10::ScalarType::ComplexDouble:
|
| 74 |
+
return CUDA_C_64F;
|
| 75 |
+
case c10::ScalarType::Short:
|
| 76 |
+
return CUDA_R_16I;
|
| 77 |
+
case c10::ScalarType::Long:
|
| 78 |
+
return CUDA_R_64I;
|
| 79 |
+
case c10::ScalarType::BFloat16:
|
| 80 |
+
return CUDA_R_16BF;
|
| 81 |
+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
|
| 82 |
+
case c10::ScalarType::Float8_e4m3fn:
|
| 83 |
+
return CUDA_R_8F_E4M3;
|
| 84 |
+
case c10::ScalarType::Float8_e5m2:
|
| 85 |
+
return CUDA_R_8F_E5M2;
|
| 86 |
+
#endif
|
| 87 |
+
#if defined(USE_ROCM)
|
| 88 |
+
#if defined(HIP_NEW_TYPE_ENUMS)
|
| 89 |
+
case c10::ScalarType::Float8_e4m3fnuz:
|
| 90 |
+
return HIP_R_8F_E4M3_FNUZ;
|
| 91 |
+
case c10::ScalarType::Float8_e5m2fnuz:
|
| 92 |
+
return HIP_R_8F_E5M2_FNUZ;
|
| 93 |
+
#else
|
| 94 |
+
case c10::ScalarType::Float8_e4m3fnuz:
|
| 95 |
+
return static_cast<hipDataType>(1000);
|
| 96 |
+
case c10::ScalarType::Float8_e5m2fnuz:
|
| 97 |
+
return static_cast<hipDataType>(1001);
|
| 98 |
+
#endif
|
| 99 |
+
#endif
|
| 100 |
+
default:
|
| 101 |
+
TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/Exceptions.h>
|
| 4 |
+
|
| 5 |
+
#include <cuda.h>
|
| 6 |
+
#include <cuda_runtime.h>
|
| 7 |
+
|
| 8 |
+
namespace at::cuda {
|
| 9 |
+
|
| 10 |
+
inline Device getDeviceFromPtr(void* ptr) {
|
| 11 |
+
cudaPointerAttributes attr{};
|
| 12 |
+
|
| 13 |
+
AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
|
| 14 |
+
|
| 15 |
+
#if !defined(USE_ROCM)
|
| 16 |
+
TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
|
| 17 |
+
"The specified pointer resides on host memory and is not registered with any CUDA device.");
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda {
|
| 6 |
+
|
| 7 |
+
// Check if every tensor in a list of tensors matches the current
|
| 8 |
+
// device.
|
| 9 |
+
inline bool check_device(ArrayRef<Tensor> ts) {
|
| 10 |
+
if (ts.empty()) {
|
| 11 |
+
return true;
|
| 12 |
+
}
|
| 13 |
+
Device curDevice = Device(kCUDA, current_device());
|
| 14 |
+
for (const Tensor& t : ts) {
|
| 15 |
+
if (t.device() != curDevice) return false;
|
| 16 |
+
}
|
| 17 |
+
return true;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <limits.h>
|
| 5 |
+
#include <math.h>
|
| 6 |
+
#include <float.h>
|
| 7 |
+
|
| 8 |
+
// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
|
| 9 |
+
// types. This header is very specific to ROCm HIP and may be removed in the future.
|
| 10 |
+
// This header is derived from the legacy THCNumerics.cuh.
|
| 11 |
+
|
| 12 |
+
// The lower_bound and upper_bound constants are same as lowest and max for
|
| 13 |
+
// integral types, but are -inf and +inf for floating point types. They are
|
| 14 |
+
// useful in implementing min, max, etc.
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
template <typename T>
|
| 19 |
+
struct numeric_limits {
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
// WARNING: the following at::numeric_limits definitions are there only to support
|
| 23 |
+
// HIP compilation for the moment. Use std::numeric_limits if you are not
|
| 24 |
+
// compiling for ROCm.
|
| 25 |
+
// from @colesbury: "The functions on numeric_limits aren't marked with
|
| 26 |
+
// __device__ which is why they don't work with ROCm. CUDA allows them
|
| 27 |
+
// because they're constexpr."
|
| 28 |
+
|
| 29 |
+
namespace {
|
| 30 |
+
// ROCm doesn't like INFINITY too.
|
| 31 |
+
constexpr double inf = INFINITY;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
template <>
|
| 35 |
+
struct numeric_limits<bool> {
|
| 36 |
+
static inline __host__ __device__ bool lowest() { return false; }
|
| 37 |
+
static inline __host__ __device__ bool max() { return true; }
|
| 38 |
+
static inline __host__ __device__ bool lower_bound() { return false; }
|
| 39 |
+
static inline __host__ __device__ bool upper_bound() { return true; }
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
template <>
|
| 43 |
+
struct numeric_limits<uint8_t> {
|
| 44 |
+
static inline __host__ __device__ uint8_t lowest() { return 0; }
|
| 45 |
+
static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
|
| 46 |
+
static inline __host__ __device__ uint8_t lower_bound() { return 0; }
|
| 47 |
+
static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; }
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
template <>
|
| 51 |
+
struct numeric_limits<int8_t> {
|
| 52 |
+
static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
|
| 53 |
+
static inline __host__ __device__ int8_t max() { return INT8_MAX; }
|
| 54 |
+
static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; }
|
| 55 |
+
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template <>
|
| 59 |
+
struct numeric_limits<int16_t> {
|
| 60 |
+
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
|
| 61 |
+
static inline __host__ __device__ int16_t max() { return INT16_MAX; }
|
| 62 |
+
static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; }
|
| 63 |
+
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
template <>
|
| 67 |
+
struct numeric_limits<int32_t> {
|
| 68 |
+
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
|
| 69 |
+
static inline __host__ __device__ int32_t max() { return INT32_MAX; }
|
| 70 |
+
static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; }
|
| 71 |
+
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
template <>
|
| 75 |
+
struct numeric_limits<int64_t> {
|
| 76 |
+
#ifdef _MSC_VER
|
| 77 |
+
static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
|
| 78 |
+
static inline __host__ __device__ int64_t max() { return _I64_MAX; }
|
| 79 |
+
static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; }
|
| 80 |
+
static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; }
|
| 81 |
+
#else
|
| 82 |
+
static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
|
| 83 |
+
static inline __host__ __device__ int64_t max() { return INT64_MAX; }
|
| 84 |
+
static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; }
|
| 85 |
+
static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; }
|
| 86 |
+
#endif
|
| 87 |
+
};
|
| 88 |
+
|
| 89 |
+
template <>
|
| 90 |
+
struct numeric_limits<at::Half> {
|
| 91 |
+
static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); }
|
| 92 |
+
static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); }
|
| 93 |
+
static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); }
|
| 94 |
+
static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); }
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
template <>
|
| 98 |
+
struct numeric_limits<at::BFloat16> {
|
| 99 |
+
static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); }
|
| 100 |
+
static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); }
|
| 101 |
+
static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); }
|
| 102 |
+
static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); }
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
template <>
|
| 106 |
+
struct numeric_limits<float> {
|
| 107 |
+
static inline __host__ __device__ float lowest() { return -FLT_MAX; }
|
| 108 |
+
static inline __host__ __device__ float max() { return FLT_MAX; }
|
| 109 |
+
static inline __host__ __device__ float lower_bound() { return -static_cast<float>(inf); }
|
| 110 |
+
static inline __host__ __device__ float upper_bound() { return static_cast<float>(inf); }
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
template <>
|
| 114 |
+
struct numeric_limits<double> {
|
| 115 |
+
static inline __host__ __device__ double lowest() { return -DBL_MAX; }
|
| 116 |
+
static inline __host__ __device__ double max() { return DBL_MAX; }
|
| 117 |
+
static inline __host__ __device__ double lower_bound() { return -inf; }
|
| 118 |
+
static inline __host__ __device__ double upper_bound() { return inf; }
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <c10/macros/Macros.h>
|
| 2 |
+
#include <cstdint>
|
| 3 |
+
|
| 4 |
+
namespace at::cuda {
|
| 5 |
+
namespace detail {
|
| 6 |
+
void init_p2p_access_cache(int64_t num_devices);
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
TORCH_CUDA_CPP_API bool get_p2p_access(int source_dev, int dest_dev);
|
| 10 |
+
|
| 11 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/macros/Export.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda {
|
| 6 |
+
|
| 7 |
+
// enqueues a kernel that spins for the specified number of cycles
|
| 8 |
+
TORCH_CUDA_CU_API void sleep(int64_t cycles);
|
| 9 |
+
|
| 10 |
+
// flushes instruction cache for ROCm; no-op for CUDA
|
| 11 |
+
TORCH_CUDA_CU_API void flush_icache();
|
| 12 |
+
|
| 13 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if !defined(USE_ROCM)
|
| 4 |
+
#include <cuda.h> // for CUDA_VERSION
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
#if !defined(USE_ROCM)
|
| 8 |
+
#include <cub/version.cuh>
|
| 9 |
+
#else
|
| 10 |
+
#define CUB_VERSION 0
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
|
| 14 |
+
// https://github.com/NVIDIA/cub/pull/306
|
| 15 |
+
#if CUB_VERSION >= 101300
|
| 16 |
+
#define CUB_SUPPORTS_NV_BFLOAT16() true
|
| 17 |
+
#else
|
| 18 |
+
#define CUB_SUPPORTS_NV_BFLOAT16() false
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
|
| 22 |
+
// https://github.com/NVIDIA/cub/pull/326
|
| 23 |
+
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
|
| 24 |
+
// starting from CUDA 11.5
|
| 25 |
+
#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
|
| 26 |
+
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
|
| 27 |
+
#else
|
| 28 |
+
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
|
| 29 |
+
#endif
|
| 30 |
+
|
| 31 |
+
// cub support for UniqueByKey is added to cub 1.16 in:
|
| 32 |
+
// https://github.com/NVIDIA/cub/pull/405
|
| 33 |
+
#if CUB_VERSION >= 101600
|
| 34 |
+
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
|
| 35 |
+
#else
|
| 36 |
+
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
|
| 37 |
+
#endif
|
| 38 |
+
|
| 39 |
+
// cub support for scan by key is added to cub 1.15
|
| 40 |
+
// in https://github.com/NVIDIA/cub/pull/376
|
| 41 |
+
#if CUB_VERSION >= 101500
|
| 42 |
+
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
|
| 43 |
+
#else
|
| 44 |
+
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
|
| 45 |
+
#endif
|
| 46 |
+
|
| 47 |
+
// cub support for cub::FutureValue is added to cub 1.15 in:
|
| 48 |
+
// https://github.com/NVIDIA/cub/pull/305
|
| 49 |
+
#if CUB_VERSION >= 101500
|
| 50 |
+
#define CUB_SUPPORTS_FUTURE_VALUE() true
|
| 51 |
+
#else
|
| 52 |
+
#define CUB_SUPPORTS_FUTURE_VALUE() false
|
| 53 |
+
#endif
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 4 |
+
|
| 5 |
+
#include <ATen/Generator.h>
|
| 6 |
+
#include <optional>
|
| 7 |
+
|
| 8 |
+
// TODO: No need to have this whole header, we can just put it all in
|
| 9 |
+
// the cpp file
|
| 10 |
+
|
| 11 |
+
namespace at::cuda::detail {
|
| 12 |
+
|
| 13 |
+
// Set the callback to initialize Magma, which is set by
|
| 14 |
+
// torch_cuda_cu. This indirection is required so magma_init is called
|
| 15 |
+
// in the same library where Magma will be used.
|
| 16 |
+
TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
// The real implementation of CUDAHooksInterface
|
| 20 |
+
struct CUDAHooks : public at::CUDAHooksInterface {
|
| 21 |
+
CUDAHooks(at::CUDAHooksArgs) {}
|
| 22 |
+
void initCUDA() const override;
|
| 23 |
+
Device getDeviceFromPtr(void* data) const override;
|
| 24 |
+
bool isPinnedPtr(const void* data) const override;
|
| 25 |
+
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
|
| 26 |
+
bool hasCUDA() const override;
|
| 27 |
+
bool hasMAGMA() const override;
|
| 28 |
+
bool hasCuDNN() const override;
|
| 29 |
+
bool hasCuSOLVER() const override;
|
| 30 |
+
bool hasCuBLASLt() const override;
|
| 31 |
+
bool hasROCM() const override;
|
| 32 |
+
const at::cuda::NVRTC& nvrtc() const override;
|
| 33 |
+
DeviceIndex current_device() const override;
|
| 34 |
+
bool hasPrimaryContext(DeviceIndex device_index) const override;
|
| 35 |
+
Allocator* getCUDADeviceAllocator() const override;
|
| 36 |
+
Allocator* getPinnedMemoryAllocator() const override;
|
| 37 |
+
bool compiledWithCuDNN() const override;
|
| 38 |
+
bool compiledWithMIOpen() const override;
|
| 39 |
+
bool supportsDilatedConvolutionWithCuDNN() const override;
|
| 40 |
+
bool supportsDepthwiseConvolutionWithCuDNN() const override;
|
| 41 |
+
bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
|
| 42 |
+
bool hasCUDART() const override;
|
| 43 |
+
long versionCUDART() const override;
|
| 44 |
+
long versionCuDNN() const override;
|
| 45 |
+
std::string showConfig() const override;
|
| 46 |
+
double batchnormMinEpsilonCuDNN() const override;
|
| 47 |
+
int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
|
| 48 |
+
void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
|
| 49 |
+
int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
|
| 50 |
+
void cuFFTClearPlanCache(DeviceIndex device_index) const override;
|
| 51 |
+
int getNumGPUs() const override;
|
| 52 |
+
#ifdef USE_ROCM
|
| 53 |
+
bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
|
| 54 |
+
#endif
|
| 55 |
+
void deviceSynchronize(DeviceIndex device_index) const override;
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
} // at::cuda::detail
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states.
|
| 2 |
+
// These handles are tied to device, and these libraries requires/recommends not to
|
| 3 |
+
// share handles across host threads.
|
| 4 |
+
//
|
| 5 |
+
// These libraries recommend using one handle per host thread. We may not want to do
|
| 6 |
+
// this because threads are relatively light-weight, but creating and destroying
|
| 7 |
+
// handles is expensive (destroying the handle causes synchronizations). DataParallel,
|
| 8 |
+
// for example, creates new threads for each forward pass.
|
| 9 |
+
//
|
| 10 |
+
// This file implements a handle pool mechanism. The handle pool returns handles on
|
| 11 |
+
// demand as threads request them. If all existing handles in the pool are in use,
|
| 12 |
+
// it creates a new one. As threads terminate, they release handles back into the pool.
|
| 13 |
+
// In this way, the handle pool never creates more handles than the high-water mark of
|
| 14 |
+
// active threads, so it's efficient with DataParallel.
|
| 15 |
+
|
| 16 |
+
#pragma once
|
| 17 |
+
|
| 18 |
+
#include <unordered_map>
|
| 19 |
+
#include <vector>
|
| 20 |
+
#include <utility>
|
| 21 |
+
#include <mutex>
|
| 22 |
+
#include <memory>
|
| 23 |
+
|
| 24 |
+
#include <c10/util/Exception.h>
|
| 25 |
+
|
| 26 |
+
namespace at::cuda { namespace {
|
| 27 |
+
|
| 28 |
+
template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
|
| 29 |
+
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {
|
| 30 |
+
|
| 31 |
+
struct Handle {
|
| 32 |
+
Handle_t handle;
|
| 33 |
+
Handle(bool create = false) : handle(nullptr)
|
| 34 |
+
{
|
| 35 |
+
if(create) Create(&handle);
|
| 36 |
+
}
|
| 37 |
+
// std::vector.emplace() and push_back() may route through temporaries and call
|
| 38 |
+
// copy/move constructors along the way. If this is the case, we don't want
|
| 39 |
+
// the destructors of temporaries to call cudnnDestroy on the handle.
|
| 40 |
+
// We can achieve safety (for the narrow case of stashing within std::vectors)
|
| 41 |
+
// by making Handle moveable but not copyable, and transferring handle ownership
|
| 42 |
+
// to the latest constructed object. This is not a substitute for full-blown
|
| 43 |
+
// reference counting, but reference counting may be overkill here.
|
| 44 |
+
// Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
|
| 45 |
+
// unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
|
| 46 |
+
Handle(const Handle& rhs) = delete;
|
| 47 |
+
// Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
|
| 48 |
+
Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
|
| 49 |
+
// operator= takes argument by value
|
| 50 |
+
Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
|
| 51 |
+
~Handle() {
|
| 52 |
+
if(handle) Destroy(handle);
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
std::mutex mutex;
|
| 57 |
+
|
| 58 |
+
// Handles are lazily created as different threads request them,
|
| 59 |
+
// but are never destroyed until the end of the process.
|
| 60 |
+
// The maximum number of handles this process will create for each device is equal
|
| 61 |
+
// to the high-water mark of the number of concurrently active threads that request
|
| 62 |
+
// handles for that device.
|
| 63 |
+
// When threads terminate, they release their handles back into the pool for reuse.
|
| 64 |
+
// Otherwise, new handles would be created every time new threads were spawned,
|
| 65 |
+
// resulting in poor performance for Python modules that repeatedly or frequently
|
| 66 |
+
// spawned new sets of threads (like DataParallel, which creates a new set of threads
|
| 67 |
+
// for each forward pass).
|
| 68 |
+
//
|
| 69 |
+
// To prevent potential deadlocks, we explicitly choose not to cap the number
|
| 70 |
+
// of handles that are created per device.
|
| 71 |
+
// Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
|
| 72 |
+
// only 4 can make forward progress at any time. The other 4 will not release their
|
| 73 |
+
// handles until they exit, so the fifth cannot make progress until then. This is
|
| 74 |
+
// not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
|
| 75 |
+
// intermediate point (ie, before any of them have exited). We have no way to anticipate
|
| 76 |
+
// or enforce that user threads will not attempt such intermediate synchronization.
|
| 77 |
+
// The only way to ensure safety is to avoid imposing a cap on the number of handles.
|
| 78 |
+
std::unordered_map<int, std::vector<Handle>> created_handles;
|
| 79 |
+
std::unordered_map<int, std::vector<Handle_t>> available_handles;
|
| 80 |
+
|
| 81 |
+
// PoolWindow lazily creates and caches the handles that a particular thread is using,
|
| 82 |
+
// so in the common case handle access doesn't incur either handle creation or a mutex lock.
|
| 83 |
+
class PoolWindow
|
| 84 |
+
{
|
| 85 |
+
public:
|
| 86 |
+
PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
|
| 87 |
+
~PoolWindow(){ release(); }
|
| 88 |
+
|
| 89 |
+
Handle_t reserve(int device)
|
| 90 |
+
{
|
| 91 |
+
// If this thread already has a handle for this device, return it
|
| 92 |
+
if(my_handles.find(device) != my_handles.end())
|
| 93 |
+
return my_handles[device];
|
| 94 |
+
|
| 95 |
+
// otherwise, either grab a handle from the pool if one is available,
|
| 96 |
+
// or if not, create a new one.
|
| 97 |
+
auto parent = weak_parent.lock();
|
| 98 |
+
TORCH_CHECK(parent, "Cannot create handle during program termination");
|
| 99 |
+
std::lock_guard<std::mutex> guard(parent->mutex);
|
| 100 |
+
|
| 101 |
+
if(parent->available_handles[device].size() > 0)
|
| 102 |
+
{
|
| 103 |
+
my_handles[device] = parent->available_handles[device].back();
|
| 104 |
+
parent->available_handles[device].pop_back();
|
| 105 |
+
}
|
| 106 |
+
else
|
| 107 |
+
{
|
| 108 |
+
// In local testing, I do observe that emplace_back sometimes routes through temporaries
|
| 109 |
+
// that incur move-constructor and destructor calls. See comments in Handle above.
|
| 110 |
+
parent->created_handles[device].emplace_back(true /*create*/);
|
| 111 |
+
my_handles[device] = parent->created_handles[device].back().handle;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
return my_handles[device];
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
private:
|
| 118 |
+
// Stores the per-device handles currently owned by this thread
|
| 119 |
+
std::unordered_map<int, Handle_t> my_handles;
|
| 120 |
+
|
| 121 |
+
std::weak_ptr<DeviceThreadHandlePool> weak_parent;
|
| 122 |
+
|
| 123 |
+
// Called by the destructor. Releases this thread's handles back into the pool.
|
| 124 |
+
void release() {
|
| 125 |
+
if(my_handles.size() > 0) {
|
| 126 |
+
auto parent = weak_parent.lock();
|
| 127 |
+
if (!parent) {
|
| 128 |
+
// If this thread exits after atexit handlers have completed, the
|
| 129 |
+
// cuda context itself may be invalid, so we must leak the handles.
|
| 130 |
+
return;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
std::lock_guard<std::mutex> guard(parent->mutex);
|
| 134 |
+
for(auto d_h : my_handles)
|
| 135 |
+
parent->available_handles[d_h.first].push_back(d_h.second);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
// Warning:
|
| 141 |
+
// If you want to change this function, be aware that this function will be called
|
| 142 |
+
// by multiple threads and there is no mutex guarding the call of this function, so
|
| 143 |
+
// make sure your implementation is thread-safe.
|
| 144 |
+
PoolWindow *newPoolWindow() {
|
| 145 |
+
// The returned pointer will be owned by a thread local variable
|
| 146 |
+
// so that different threads does not share the same PoolWindow.
|
| 147 |
+
return new PoolWindow(this->shared_from_this());
|
| 148 |
+
}
|
| 149 |
+
};
|
| 150 |
+
|
| 151 |
+
}} // namespace at::cuda::detail::<anonymous>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <assert.h>
|
| 4 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
| 5 |
+
#include <cuda_runtime.h>
|
| 6 |
+
#endif
|
| 7 |
+
|
| 8 |
+
namespace at::cuda::detail {
|
| 9 |
+
|
| 10 |
+
// A utility class to implement integer division by multiplication, given a fixed
|
| 11 |
+
// divisor.
|
| 12 |
+
//
|
| 13 |
+
// WARNING: The fast divider algorithm is only implemented for unsigned int;
|
| 14 |
+
// otherwise we default to plain integer division. For unsigned int,
|
| 15 |
+
// we further assume that the dividend is at most INT32_MAX. Thus,
|
| 16 |
+
// IntDivider must NOT be used for general integer division.
|
| 17 |
+
//
|
| 18 |
+
// This reduced range is enough for our purpose, and it allows us to
|
| 19 |
+
// slightly simplify the computation.
|
| 20 |
+
//
|
| 21 |
+
// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1<<k.)
|
| 22 |
+
//
|
| 23 |
+
// For any N-bit unsigned integer d (> 0), we can find a "magic number" m (2^N
|
| 24 |
+
// <= m < 2^(N+1)) and shift s such that:
|
| 25 |
+
//
|
| 26 |
+
// \floor(n / d) = \floor((m * n) / 2^(N+s)).
|
| 27 |
+
//
|
| 28 |
+
// Given such m and s, the integer division can be then implemented as:
|
| 29 |
+
//
|
| 30 |
+
// let m' = m - 2^N // 0 <= m' < 2^N
|
| 31 |
+
//
|
| 32 |
+
// fast_integer_division(n):
|
| 33 |
+
// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned
|
| 34 |
+
// // integer. Then take the higher N bits.
|
| 35 |
+
// t = (m' * n) >> N
|
| 36 |
+
//
|
| 37 |
+
// // Here we use the fact that n is less than 2^(N-1): otherwise the value
|
| 38 |
+
// // of (t + n) may not fit in an N-bit integer.
|
| 39 |
+
// return (t + n) >> s
|
| 40 |
+
//
|
| 41 |
+
// Finding such a magic number is surprisingly easy:
|
| 42 |
+
//
|
| 43 |
+
// s = \ceil(\log_2 d)
|
| 44 |
+
// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic.
|
| 45 |
+
//
|
| 46 |
+
// See also:
|
| 47 |
+
// - Division by Invariant Integers Using Multiplication,
|
| 48 |
+
// Torbjörn Granlund and Peter L. Montgomery, 1994.
|
| 49 |
+
//
|
| 50 |
+
// - http://www.hackersdelight.org/magic.htm
|
| 51 |
+
//
|
| 52 |
+
// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
|
| 53 |
+
|
| 54 |
+
// Result of div/mod operation stored together.
|
| 55 |
+
template <typename Value>
|
| 56 |
+
struct DivMod {
|
| 57 |
+
Value div, mod;
|
| 58 |
+
|
| 59 |
+
C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { }
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
// Base case: we only have an implementation for uint32_t for now. For
|
| 63 |
+
// everything else, we use plain division.
|
| 64 |
+
template <typename Value>
|
| 65 |
+
struct IntDivider {
|
| 66 |
+
IntDivider() = default;
|
| 67 |
+
IntDivider(Value d) : divisor(d) { }
|
| 68 |
+
|
| 69 |
+
C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; }
|
| 70 |
+
C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; }
|
| 71 |
+
C10_HOST_DEVICE inline DivMod<Value> divmod(Value n) const {
|
| 72 |
+
return DivMod<Value>(n / divisor, n % divisor);
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
Value divisor;
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
// Implement fast integer division.
|
| 79 |
+
template <>
|
| 80 |
+
struct IntDivider<unsigned int> {
|
| 81 |
+
static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int.");
|
| 82 |
+
|
| 83 |
+
IntDivider() = default;
|
| 84 |
+
|
| 85 |
+
IntDivider(unsigned int d) : divisor(d) {
|
| 86 |
+
assert(divisor >= 1 && divisor <= INT32_MAX);
|
| 87 |
+
|
| 88 |
+
// TODO: gcc/clang has __builtin_clz() but it's not portable.
|
| 89 |
+
for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break;
|
| 90 |
+
|
| 91 |
+
uint64_t one = 1;
|
| 92 |
+
uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
|
| 93 |
+
m1 = magic;
|
| 94 |
+
assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits.
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
C10_HOST_DEVICE inline unsigned int div(unsigned int n) const {
|
| 98 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
| 99 |
+
// 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and
|
| 100 |
+
// 'm1'.
|
| 101 |
+
unsigned int t = __umulhi(n, m1);
|
| 102 |
+
return (t + n) >> shift;
|
| 103 |
+
#else
|
| 104 |
+
// Using uint64_t so that the addition does not overflow.
|
| 105 |
+
uint64_t t = ((uint64_t) n * m1) >> 32;
|
| 106 |
+
return (t + n) >> shift;
|
| 107 |
+
#endif
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const {
|
| 111 |
+
return n - div(n) * divisor;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
C10_HOST_DEVICE inline DivMod<unsigned int> divmod(unsigned int n) const {
|
| 115 |
+
unsigned int q = div(n);
|
| 116 |
+
return DivMod<unsigned int>(q, n - q * divisor);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
unsigned int divisor; // d above.
|
| 120 |
+
unsigned int m1; // Magic number: m' above.
|
| 121 |
+
unsigned int shift; // Shift amounts.
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
} // namespace at::cuda::detail
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <limits>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
|
| 6 |
+
namespace at::cuda::detail {
|
| 7 |
+
|
| 8 |
+
// CUDA: grid stride looping
|
| 9 |
+
//
|
| 10 |
+
// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
|
| 11 |
+
// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
|
| 12 |
+
// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
|
| 13 |
+
// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
|
| 14 |
+
// further iterations and the overflowed value in i=_i_n_d_e_x is not used.
|
| 15 |
+
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
|
| 16 |
+
int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
|
| 17 |
+
for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
|
| 18 |
+
|
| 19 |
+
#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
// Use 1024 threads per block, which requires cuda sm_2x or above
|
| 23 |
+
constexpr int CUDA_NUM_THREADS = 1024;
|
| 24 |
+
|
| 25 |
+
// CUDA: number of blocks for threads.
|
| 26 |
+
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
|
| 27 |
+
TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
|
| 28 |
+
constexpr int64_t max_int = std::numeric_limits<int>::max();
|
| 29 |
+
|
| 30 |
+
// Round up division for positive number that cannot cause integer overflow
|
| 31 |
+
auto block_num = (N - 1) / max_threads_per_block + 1;
|
| 32 |
+
TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
|
| 33 |
+
|
| 34 |
+
return static_cast<int>(block_num);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
} // namespace at::cuda::detail
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 3 |
+
namespace at::cuda {
|
| 4 |
+
// Forward-declares at::cuda::NVRTC
|
| 5 |
+
struct NVRTC;
|
| 6 |
+
|
| 7 |
+
namespace detail {
|
| 8 |
+
extern NVRTC lazyNVRTC;
|
| 9 |
+
} // namespace detail
|
| 10 |
+
|
| 11 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
|
| 2 |
+
// Eager mode clients should not include this file directly, instead,
|
| 3 |
+
// they should #include <ATen/cuda/PhiloxCudaState.h>, which has a #pragma once.
|
| 4 |
+
|
| 5 |
+
// Stores RNG state values. Passed as a kernel argument.
|
| 6 |
+
// See Note [CUDA Graph-safe RNG states].
|
| 7 |
+
//
|
| 8 |
+
// The raw definition lives in its own file so jit codegen can easily copy it.
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
struct PhiloxCudaState {
|
| 12 |
+
PhiloxCudaState() = default;
|
| 13 |
+
// Called if graph capture is not underway
|
| 14 |
+
PhiloxCudaState(uint64_t seed,
|
| 15 |
+
uint64_t offset) {
|
| 16 |
+
seed_.val = seed;
|
| 17 |
+
offset_.val = offset;
|
| 18 |
+
}
|
| 19 |
+
// Called if graph capture is underway
|
| 20 |
+
PhiloxCudaState(int64_t* seed,
|
| 21 |
+
int64_t* offset_extragraph,
|
| 22 |
+
uint32_t offset_intragraph) {
|
| 23 |
+
seed_.ptr = seed;
|
| 24 |
+
offset_.ptr = offset_extragraph;
|
| 25 |
+
offset_intragraph_ = offset_intragraph;
|
| 26 |
+
captured_ = true;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// Public members, directly accessible by at::cuda::philox::unpack.
|
| 30 |
+
// If we made them private with getters/setters, the getters/setters
|
| 31 |
+
// would have to be __device__, and we can't declare __device__ in ATen.
|
| 32 |
+
union Payload {
|
| 33 |
+
uint64_t val;
|
| 34 |
+
int64_t* ptr;
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
Payload seed_{};
|
| 38 |
+
Payload offset_{};
|
| 39 |
+
uint32_t offset_intragraph_ = 0;
|
| 40 |
+
bool captured_ = false;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/CollapseDims.h>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda::detail {
|
| 6 |
+
|
| 7 |
+
#define MAX_TENSORINFO_DIMS 25
|
| 8 |
+
|
| 9 |
+
// CUDA kernel argument that defines tensor layout
|
| 10 |
+
template <typename T, typename IndexType>
|
| 11 |
+
struct TensorInfo {
|
| 12 |
+
TensorInfo();
|
| 13 |
+
TensorInfo(T* p,
|
| 14 |
+
int dim,
|
| 15 |
+
IndexType sz[MAX_TENSORINFO_DIMS],
|
| 16 |
+
IndexType st[MAX_TENSORINFO_DIMS]);
|
| 17 |
+
|
| 18 |
+
// Set the size of the given dimension to 1, as if it were a
|
| 19 |
+
// reduction dim (allows you to calculate offsets of the reduction
|
| 20 |
+
// slice)
|
| 21 |
+
void reduceDim(int dim);
|
| 22 |
+
|
| 23 |
+
// See note on [collapse dims].
|
| 24 |
+
int collapseDims(const int excludeDim = -1);
|
| 25 |
+
|
| 26 |
+
// Contiguous tensors of more than one dimension are collapsed down
|
| 27 |
+
// to one tensor
|
| 28 |
+
__host__ __device__ inline bool isContiguous() const {
|
| 29 |
+
return (dims == 1 && strides[0] == 1);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
T* data;
|
| 33 |
+
IndexType sizes[MAX_TENSORINFO_DIMS];
|
| 34 |
+
IndexType strides[MAX_TENSORINFO_DIMS];
|
| 35 |
+
int dims;
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
template <typename T, typename IndexType>
|
| 39 |
+
TensorInfo<T, IndexType>::TensorInfo() {
|
| 40 |
+
data = nullptr;
|
| 41 |
+
dims = 0;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
template <typename T, typename IndexType>
|
| 45 |
+
TensorInfo<T, IndexType>::TensorInfo(T* p,
|
| 46 |
+
int dim,
|
| 47 |
+
IndexType sz[MAX_TENSORINFO_DIMS],
|
| 48 |
+
IndexType st[MAX_TENSORINFO_DIMS]) {
|
| 49 |
+
data = p;
|
| 50 |
+
dims = dim;
|
| 51 |
+
TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");
|
| 52 |
+
|
| 53 |
+
for (int i = 0; i < dim; ++i) {
|
| 54 |
+
sizes[i] = sz[i];
|
| 55 |
+
strides[i] = st[i];
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template <typename T, typename IndexType>
|
| 60 |
+
void
|
| 61 |
+
TensorInfo<T, IndexType>::reduceDim(int dim) {
|
| 62 |
+
TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
|
| 63 |
+
sizes[dim] = 1;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
template <typename T, typename IndexType>
|
| 67 |
+
int
|
| 68 |
+
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
|
| 69 |
+
auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
|
| 70 |
+
dims = std::get<1>(result);
|
| 71 |
+
return std::get<0>(result);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Translate a linear index for the apply to a T* offset;
|
| 75 |
+
// specialized on `Dims` to reduce nvcc compilation time
|
| 76 |
+
template <typename T, typename IndexType, int Dims>
|
| 77 |
+
struct IndexToOffset {
|
| 78 |
+
static __host__ __device__ IndexType get(
|
| 79 |
+
IndexType linearId,
|
| 80 |
+
const TensorInfo<T, IndexType>& info) {
|
| 81 |
+
|
| 82 |
+
IndexType offset = 0;
|
| 83 |
+
|
| 84 |
+
// Uses static dims
|
| 85 |
+
for (int i = Dims - 1; i > 0; --i) {
|
| 86 |
+
IndexType curDimIndex = linearId % info.sizes[i];
|
| 87 |
+
IndexType curDimOffset = curDimIndex * info.strides[i];
|
| 88 |
+
offset += curDimOffset;
|
| 89 |
+
linearId /= info.sizes[i];
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
return offset + linearId * info.strides[0];
|
| 93 |
+
}
|
| 94 |
+
};
|
| 95 |
+
|
| 96 |
+
// Uses dynamic (runtime) instead of static (compiletime) dims
|
| 97 |
+
template <typename T, typename IndexType>
|
| 98 |
+
struct IndexToOffset<T, IndexType, -1> {
|
| 99 |
+
static inline __host__ __device__ IndexType get(
|
| 100 |
+
IndexType linearId,
|
| 101 |
+
const TensorInfo<T, IndexType>& info) {
|
| 102 |
+
|
| 103 |
+
IndexType offset = 0;
|
| 104 |
+
|
| 105 |
+
for (int i = info.dims - 1; i > 0; --i) {
|
| 106 |
+
IndexType curDimIndex = linearId % info.sizes[i];
|
| 107 |
+
IndexType curDimOffset = curDimIndex * info.strides[i];
|
| 108 |
+
offset += curDimOffset;
|
| 109 |
+
linearId /= info.sizes[i];
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
return offset + linearId * info.strides[0];
|
| 113 |
+
}
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
} // namespace at::cuda::detail
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// No "#pragma once" because this is a raw definition that can be copied by jit codegen.
|
| 2 |
+
// Eager mode clients should not include this file directly, instead,
|
| 3 |
+
// they should #include <ATen/cuda/PhiloxUtils.cuh>, which has a #pragma once.
|
| 4 |
+
|
| 5 |
+
namespace at::cuda::philox {
|
| 6 |
+
|
| 7 |
+
// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether
|
| 8 |
+
// that instance was created with graph capture underway or not.
|
| 9 |
+
// See Note [CUDA Graph-safe RNG states].
|
| 10 |
+
//
|
| 11 |
+
// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen.
|
| 12 |
+
// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable.
|
| 13 |
+
// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
|
| 14 |
+
//
|
| 15 |
+
// The raw definition lives in its own file so jit codegen can easily copy it.
|
| 16 |
+
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
|
| 17 |
+
unpack(at::PhiloxCudaState arg) {
|
| 18 |
+
if (arg.captured_) {
|
| 19 |
+
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
|
| 20 |
+
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
|
| 21 |
+
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
|
| 22 |
+
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
|
| 23 |
+
} else {
|
| 24 |
+
return std::make_tuple(arg.seed_.val, arg.offset_.val);
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
} // namespace at::cuda::philox
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <c10/macros/Export.h>
|
| 5 |
+
|
| 6 |
+
namespace at::cuda {
|
| 7 |
+
|
| 8 |
+
TORCH_CUDA_CPP_API const std::string &get_traits_string();
|
| 9 |
+
TORCH_CUDA_CPP_API const std::string &get_cmath_string();
|
| 10 |
+
TORCH_CUDA_CPP_API const std::string &get_complex_body_string();
|
| 11 |
+
TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string();
|
| 12 |
+
TORCH_CUDA_CPP_API const std::string &get_complex_math_string();
|
| 13 |
+
|
| 14 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <string>
|
| 13 |
+
|
| 14 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 15 |
+
#include <ATen/cuda/Exceptions.h>
|
| 16 |
+
#include <c10/util/StringUtil.h>
|
| 17 |
+
|
| 18 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 19 |
+
#include <ATen/Functions.h>
|
| 20 |
+
#include <ATen/NativeFunctions.h>
|
| 21 |
+
#else
|
| 22 |
+
#include <ATen/ops/allclose.h>
|
| 23 |
+
#include <ATen/ops/from_blob.h>
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
namespace at::cuda::tunable {
|
| 27 |
+
|
| 28 |
+
enum class BlasOp {
|
| 29 |
+
N = 0,
|
| 30 |
+
T = 1
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
// Render a BlasOp as its single-letter BLAS transpose code ("N" or "T").
// Falls through to a hard failure for any value outside the enum.
inline std::string BlasOpToString(BlasOp op) {
  if (op == BlasOp::N) {
    return "N";
  }
  if (op == BlasOp::T) {
    return "T";
  }
  TORCH_CHECK(false, "unrecognized BlasOp");
  return "N";
}
|
| 43 |
+
|
| 44 |
+
namespace detail {
|
| 45 |
+
|
| 46 |
+
// Element-wise comparison of two device buffers of `size` elements of type
// `dtype`, used to validate a candidate GEMM solution against a reference
// output. Both buffers are viewed as 1-D tensors, upcast to float, and tested
// with at::allclose over a ladder of tolerances from coarse (1e-1) to fine
// (1e-5). Returns false if even the coarsest tolerance fails; otherwise logs
// the tightest tolerance that passed and returns true.
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
  // comparison done as 1D tensor
  at::Tensor ref = at::from_blob(c, {size}, options);
  at::Tensor oth = at::from_blob(other_c, {size}, options);
  at::Tensor ref_float = ref.to(at::kFloat);
  at::Tensor oth_float = oth.to(at::kFloat);
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  // Sentinel 1 means "no tolerance ever passed" (all real tolerances are <1).
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  // Loops run coarse-to-fine without breaking, so the recorded pair is the
  // tightest (last) tolerance combination that still passed.
  for (auto& atol : atols) {
    for (auto& rtol : rtols) {
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
        last_succeed_atol = atol;
        last_succeed_rtol = rtol;
      }
    }
  }
  if (last_succeed_atol == 1) {
    return false;
  }
  else {
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
  }

  return true;
}
|
| 74 |
+
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// Parameter bundle for a plain GEMM tunable op: problem shape, scalars and
// device pointers. Provides buffer-duplication helpers so candidate solutions
// can be timed/validated without clobbering the caller's output buffer.
template <typename T>
struct GemmParams : OpParams {
  GemmParams() {
    // Ensure DeepCopy()/Delete() never observe an indeterminate flag.
    duplicate_inputs_ = false;
  }

  // Cache key for the tuning-results table, e.g. "NT_128_64_256".
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }

  // Bytes spanned by A: lda times the stored column count
  // (k when not transposed, m when transposed).
  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
  }

  // Bytes spanned by B (n columns when not transposed, k otherwise).
  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
  }

  // Bytes spanned by the output C.
  size_t GetSizeC() const {
    return sizeof(T) * ldc * n;
  }

  // Total scratch bytes a DeepCopy(duplicate_inputs) would allocate.
  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  // Shallow-copies all fields, then gives the copy its own C buffer,
  // async-copied from this->c on the current stream. With duplicate_inputs,
  // fresh A and B buffers are allocated as well, but their contents are NOT
  // copied (presumably only distinct memory is needed during tuning — confirm
  // with TunableOp callers). Release the returned object with Delete().
  GemmParams* DeepCopy(bool duplicate_inputs) const {
    GemmParams* copy = new GemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
    }
  }

  // Compare this->c against other->c over ldc*n elements (see
  // detail::NumericalCheck). Returns OK on match, FAIL otherwise.
  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
  }

  char transa;              // 'n'/'N' or 't'/'T': layout of A
  char transb;              // layout of B
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha; // scalar multiplier on A*B
  const T* a;
  int64_t lda;              // leading dimension of A
  const T* b;
  int64_t ldb;              // leading dimension of B
  at::opmath_type<T> beta;  // scalar multiplier on pre-existing C
  T* c;
  int64_t ldc;              // leading dimension of C
 private:
  bool duplicate_inputs_;   // true only on objects produced by DeepCopy(true)
};
|
| 157 |
+
|
| 158 |
+
template <typename T>
|
| 159 |
+
struct GemmAndBiasParams : OpParams {
|
| 160 |
+
std::string Signature() const override {
|
| 161 |
+
return c10::str(transa, transb, "_", m, "_", n, "_", k);
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
size_t GetSize(bool duplicate_inputs) const {
|
| 165 |
+
size_t size = sizeof(T) * ldc * n;
|
| 166 |
+
if (duplicate_inputs) {
|
| 167 |
+
size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
|
| 168 |
+
size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
|
| 169 |
+
}
|
| 170 |
+
return size;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const {
|
| 174 |
+
GemmAndBiasParams* copy = new GemmAndBiasParams;
|
| 175 |
+
*copy = *this;
|
| 176 |
+
c10::DeviceIndex device = 0;
|
| 177 |
+
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
| 178 |
+
size_t c_size = ldc * n * sizeof(T);
|
| 179 |
+
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
|
| 180 |
+
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
| 181 |
+
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
| 182 |
+
if (duplicate_inputs) {
|
| 183 |
+
size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
|
| 184 |
+
size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
|
| 185 |
+
copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
|
| 186 |
+
copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
|
| 187 |
+
copy->duplicate_inputs_ = true;
|
| 188 |
+
}
|
| 189 |
+
return copy;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
// only call on object returned by DeepCopy
|
| 193 |
+
void Delete() {
|
| 194 |
+
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
| 195 |
+
if (duplicate_inputs_) {
|
| 196 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
|
| 197 |
+
c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
|
| 202 |
+
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
| 203 |
+
return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
char transa;
|
| 207 |
+
char transb;
|
| 208 |
+
int64_t m;
|
| 209 |
+
int64_t n;
|
| 210 |
+
int64_t k;
|
| 211 |
+
at::opmath_type<T> alpha;
|
| 212 |
+
const T* a;
|
| 213 |
+
int64_t lda;
|
| 214 |
+
const T* b;
|
| 215 |
+
int64_t ldb;
|
| 216 |
+
T* c;
|
| 217 |
+
int64_t ldc;
|
| 218 |
+
const T* bias;
|
| 219 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue activation;
|
| 220 |
+
private:
|
| 221 |
+
bool duplicate_inputs_;
|
| 222 |
+
};
|
| 223 |
+
|
| 224 |
+
// Parameter bundle for a strided-batched GEMM tunable op: per-matrix shape
// plus batch count and inter-matrix strides. Duplication helpers mirror
// GemmParams but scale sizes by the batch dimension.
template <typename T>
struct GemmStridedBatchedParams : OpParams {
  GemmStridedBatchedParams() {
    // Ensure DeepCopy()/Delete() never observe an indeterminate flag.
    duplicate_inputs_ = false;
  }

  // Cache key including the batch count, e.g. "NT_128_64_256_B_8".
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
  }

  // Bytes spanned by all batches of A.
  // NOTE(review): uses the smaller of lda and stride_a as the per-matrix
  // pitch — presumably to cover packed layouts where the batch stride is
  // tighter than ld * cols; confirm against TunableOp callers.
  size_t GetSizeA() const {
    return sizeof(T) * std::min(lda, stride_a) * ((transa == 'n' || transa == 'N') ? k : m) * batch;
  }

  // Bytes spanned by all batches of B (same min(ld, stride) convention).
  size_t GetSizeB() const {
    return sizeof(T) * std::min(ldb, stride_b) * ((transb == 'n' || transb == 'N') ? n : k) * batch;
  }

  // Bytes spanned by all batches of the output C.
  size_t GetSizeC() const {
    return sizeof(T) * std::min(ldc, stride_c) * n * batch;
  }

  // Total scratch bytes a DeepCopy(duplicate_inputs) would allocate.
  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  // Shallow-copies all fields, then gives the copy its own C buffer,
  // async-copied from this->c on the current stream. With duplicate_inputs,
  // fresh A and B buffers are allocated as well, but their contents are NOT
  // copied. Release the returned object with Delete().
  GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
    }
  }

  // Compare this->c against other->c over batch*stride_c elements (see
  // detail::NumericalCheck). Returns OK on match, FAIL otherwise.
  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL;
  }

  char transa;              // 'n'/'N' or 't'/'T': layout of A
  char transb;              // layout of B
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha; // scalar multiplier on A*B
  const T* a;
  int64_t lda;              // leading dimension of each A matrix
  int64_t stride_a;         // elements between consecutive A matrices
  const T* b;
  int64_t ldb;
  int64_t stride_b;
  at::opmath_type<T> beta;  // scalar multiplier on pre-existing C
  T* c;
  int64_t ldc;
  int64_t stride_c;
  int64_t batch;            // number of matrices in the batch
 private:
  bool duplicate_inputs_;   // true only on objects produced by DeepCopy(true)
};
|
| 308 |
+
|
| 309 |
+
// Parameter bundle for scaled GEMM (FP8-style): operands carry per-tensor
// scale pointers and explicit runtime dtypes, so the data pointers are
// type-erased void* rather than T*.
template <typename T>
struct ScaledGemmParams : OpParams {
  ScaledGemmParams() {
    // Ensure DeepCopy()/Delete() never observe an indeterminate flag.
    duplicate_inputs_ = false;
  }

  // Cache key for the tuning-results table, e.g. "NT_128_64_256".
  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }

  // Bytes spanned by A: lda times the stored column count. sizeof(T) is the
  // template element size even though the pointers themselves are void*.
  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
  }

  // Bytes spanned by B.
  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
  }

  // Bytes spanned by the output C.
  size_t GetSizeC() const {
    return sizeof(T) * ldc * n;
  }

  // Total scratch bytes a DeepCopy(duplicate_inputs) would allocate.
  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  // Shallow-copies all fields, then gives the copy its own C buffer,
  // async-copied from this->c on the current stream. With duplicate_inputs,
  // fresh A and B buffers are allocated as well, but their contents are NOT
  // copied. Release the returned object with Delete().
  ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
    ScaledGemmParams* copy = new ScaledGemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    // No cast needed: c is void*.
    copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
      copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
    }
  }

  // Compare this->c against other->c over ldc*n elements using the runtime
  // c_dtype field (see detail::NumericalCheck).
  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
  }

  char transa;              // 'n'/'N' or 't'/'T': layout of A
  char transb;              // layout of B
  int64_t m;
  int64_t n;
  int64_t k;
  const void* a;
  const void* a_scale_ptr;  // per-tensor scale for A
  int64_t lda;
  ScalarType a_dtype;       // runtime element type of A
  const void* b;
  const void* b_scale_ptr;  // per-tensor scale for B
  int64_t ldb;
  ScalarType b_dtype;       // runtime element type of B
  const void* bias_ptr;     // optional bias (see bias_dtype)
  ScalarType bias_dtype;
  void* c;
  const void* c_scale_ptr;  // per-tensor scale for C
  int64_t ldc;
  ScalarType c_dtype;       // runtime element type of C
  void* amax_ptr;           // output absolute-max, written by the kernel
  bool use_fast_accum;
 private:
  bool duplicate_inputs_;   // true only on objects produced by DeepCopy(true)
};
|
| 396 |
+
|
| 397 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 7 |
+
#include <ATen/cuda/CUDADataType.h>
|
| 8 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 9 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 10 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 11 |
+
#include <c10/util/StringUtil.h>
|
| 12 |
+
|
| 13 |
+
#include <hipblaslt/hipblaslt.h>
|
| 14 |
+
#include <hipblaslt/hipblaslt-ext.hpp>
|
| 15 |
+
|
| 16 |
+
#define TORCH_HIPBLASLT_CHECK(EXPR) \
|
| 17 |
+
do { \
|
| 18 |
+
hipblasStatus_t __err = EXPR; \
|
| 19 |
+
TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
|
| 20 |
+
"hipblaslt error: ", \
|
| 21 |
+
hipblasStatusToString(__err), \
|
| 22 |
+
" when calling `" #EXPR "`"); \
|
| 23 |
+
} while (0)
|
| 24 |
+
|
| 25 |
+
namespace at::cuda::tunable {
|
| 26 |
+
|
| 27 |
+
// Map a C++ element type to the corresponding HIP blas data-type enum.
// Only the specializations below are provided; instantiating the primary
// template with any other type fails at link time (declared, not defined).
template <typename T>
constexpr hipblasDatatype_t HipDataTypeFor();

template <>
constexpr hipblasDatatype_t HipDataTypeFor<float>() {
  return HIP_R_32F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
  return HIP_R_16F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
  return HIP_R_16BF;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<double>() {
  return HIP_R_64F;
}

// FNUZ FP8 variants are the AMD/ROCm flavors of float8.
template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
  return HIP_R_8F_E4M3_FNUZ;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
  return HIP_R_8F_E5M2_FNUZ;
}
|
| 59 |
+
|
| 60 |
+
// Overload set returning the batch count for each params flavor; only the
// strided-batched variant carries a real batch dimension, the rest are
// single-matrix GEMMs.
template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->batch;
}

template <typename T>
int GetBatchFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}
|
| 79 |
+
|
| 80 |
+
template <typename T>
|
| 81 |
+
int GetStrideAFromParams(const GemmParams<T>* params) {
|
| 82 |
+
return 1;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
template <typename T>
|
| 86 |
+
int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
|
| 87 |
+
return 1;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
template <typename T>
|
| 91 |
+
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 92 |
+
return params->stride_a;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
template <typename T>
|
| 96 |
+
int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
|
| 97 |
+
return 1;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
template <typename T>
|
| 101 |
+
int GetStrideBFromParams(const GemmParams<T>* params) {
|
| 102 |
+
return 1;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
template <typename T>
|
| 106 |
+
int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
|
| 107 |
+
return 1;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
template <typename T>
|
| 111 |
+
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 112 |
+
return params->stride_b;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
template <typename T>
|
| 116 |
+
int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
|
| 117 |
+
return 1;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
template <typename T>
|
| 121 |
+
int GetStrideCFromParams(const GemmParams<T>* params) {
|
| 122 |
+
return 1;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
template <typename T>
|
| 126 |
+
int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
|
| 127 |
+
return 1;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
template <typename T>
|
| 131 |
+
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 132 |
+
return params->stride_c;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
template <typename T>
|
| 136 |
+
int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
|
| 137 |
+
return 1;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
template <typename T>
|
| 141 |
+
float GetAlphaFromParams(const GemmParams<T>* params) {
|
| 142 |
+
return params->alpha;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
template <typename T>
|
| 146 |
+
float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
|
| 147 |
+
return params->alpha;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
template <typename T>
|
| 151 |
+
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 152 |
+
return params->alpha;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
template <typename T>
|
| 156 |
+
float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
|
| 157 |
+
return 1.0;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
template <typename T>
|
| 161 |
+
float GetBetaFromParams(const GemmParams<T>* params) {
|
| 162 |
+
return params->beta;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
template <typename T>
|
| 166 |
+
float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
|
| 167 |
+
return 0.0;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
template <typename T>
|
| 171 |
+
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 172 |
+
return params->beta;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
template <typename T>
|
| 176 |
+
float GetBetaFromParams(const ScaledGemmParams<T>* params) {
|
| 177 |
+
return 0.0;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
template <typename T>
|
| 181 |
+
const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
|
| 182 |
+
return nullptr;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
template <typename T>
|
| 186 |
+
const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 187 |
+
return nullptr;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
template <typename T>
|
| 191 |
+
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 192 |
+
return nullptr;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
template <typename T>
|
| 196 |
+
const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
|
| 197 |
+
return params->a_scale_ptr;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
template <typename T>
|
| 201 |
+
const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
|
| 202 |
+
return nullptr;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
template <typename T>
|
| 206 |
+
const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 207 |
+
return nullptr;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
template <typename T>
|
| 211 |
+
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 212 |
+
return nullptr;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
template <typename T>
|
| 216 |
+
const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
|
| 217 |
+
return params->b_scale_ptr;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
template <typename T>
|
| 221 |
+
const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
|
| 222 |
+
return nullptr;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
template <typename T>
|
| 226 |
+
const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 227 |
+
return nullptr;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
template <typename T>
|
| 231 |
+
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 232 |
+
return nullptr;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
template <typename T>
|
| 236 |
+
const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
|
| 237 |
+
return params->c_scale_ptr;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
template <typename T>
|
| 241 |
+
const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
|
| 242 |
+
return nullptr;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
template <typename T>
|
| 246 |
+
const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
|
| 247 |
+
return params->bias;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
template <typename T>
|
| 251 |
+
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 252 |
+
return nullptr;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
template <typename T>
|
| 256 |
+
const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
|
| 257 |
+
return params->bias_ptr;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
template <typename T>
|
| 261 |
+
hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
|
| 262 |
+
return HIP_R_32F;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
template <typename T>
|
| 266 |
+
hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
|
| 267 |
+
return HipDataTypeFor<T>();
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
template <typename T>
|
| 271 |
+
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 272 |
+
return HIP_R_32F;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
template <typename T>
|
| 276 |
+
hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
|
| 277 |
+
return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
template <typename T>
|
| 281 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
|
| 282 |
+
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
template <typename T>
|
| 286 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
|
| 287 |
+
return params->activation;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
template <typename T>
|
| 291 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
|
| 292 |
+
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
template <typename T>
|
| 296 |
+
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
|
| 297 |
+
return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
// Translate a BLAS transpose character ('n', 't' or 'c', either case) into
// the corresponding hipblasOperation_t; errors out on anything else.
static hipblasOperation_t _hipblasOpFromChar(char op) {
  if (op == 'n' || op == 'N') {
    return HIPBLAS_OP_N;
  }
  if (op == 't' || op == 'T') {
    return HIPBLAS_OP_T;
  }
  if (op == 'c' || op == 'C') {
    return HIPBLAS_OP_C;
  }
  AT_ERROR(
      "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
|
| 315 |
+
|
| 316 |
+
// Inverse of _hipblasOpFromChar: render a hipblasOperation_t as its
// canonical uppercase BLAS character; errors out on unknown values.
static char _charFromhipblasOp(hipblasOperation_t op) {
  if (op == HIPBLAS_OP_N) {
    return 'N';
  }
  if (op == HIPBLAS_OP_T) {
    return 'T';
  }
  if (op == HIPBLAS_OP_C) {
    return 'C';
  }
  AT_ERROR(
      "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
}
|
| 328 |
+
|
| 329 |
+
// BlasOp only has N and T members, so anything other than N is transpose.
static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
  return (layout == BlasOp::N) ? HIPBLAS_OP_N : HIPBLAS_OP_T;
}
|
| 335 |
+
|
| 336 |
+
static size_t GetHipblasltWorkspaceSize() {
|
| 337 |
+
static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
|
| 338 |
+
// 256MB is max workspace size allowed for hipblaslt
|
| 339 |
+
// hipblaslt-bench uses 32MB
|
| 340 |
+
// recommendation from hipblaslt author was 76MB
|
| 341 |
+
size_t workspace_size = 32*1024; // going with 32MB
|
| 342 |
+
if (env) {
|
| 343 |
+
try {
|
| 344 |
+
workspace_size = std::stoi(env);
|
| 345 |
+
} catch(std::invalid_argument const& e) {
|
| 346 |
+
TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
|
| 347 |
+
" using default workspace size of ", workspace_size, " KiB.");
|
| 348 |
+
} catch(std::out_of_range const& e) {
|
| 349 |
+
TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
|
| 350 |
+
" using default workspace size of ", workspace_size, " KiB.");
|
| 351 |
+
}
|
| 352 |
+
}
|
| 353 |
+
return workspace_size * 1024;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
// unique_ptr deleter that releases a hipBLASLt object via the supplied
// destroy function, ignoring null pointers.
// NOTE(review): the status type is spelled cublasStatus_t here while
// HipBlasLtDescriptor below uses hipblasStatus_t — presumably both resolve
// to the same type under the ROCm/hipify aliases; confirm.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct HipBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};
|
| 364 |
+
|
| 365 |
+
// RAII base for opaque hipBLASLt descriptor types: ownership lives in a
// unique_ptr whose deleter invokes the matching destroy function.
template <typename T, hipblasStatus_t (*destructor)(T*)>
class HipBlasLtDescriptor {
 public:
  // Raw descriptor access (const and non-const); ownership is retained.
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
};
|
| 378 |
+
|
| 379 |
+
// Owning wrapper for a hipblasLtMatmulDesc_t, created with the given compute
// and scale types; destroyed via hipblasLtMatmulDescDestroy by the base.
class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
                                      hipblasLtMatmulDescOpaque_t,
                                      &hipblasLtMatmulDescDestroy> {
 public:
  HipBlasLtMatmulDescriptor(
      hipblasComputeType_t compute_type,
      hipDataType scale_type) {
    hipblasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_HIPBLASLT_CHECK(
        hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  // Set a matmul-descriptor attribute; the attribute size is sizeof(T), so T
  // must exactly match the type the attribute expects.
  template <typename T>
  inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
|
| 396 |
+
|
| 397 |
+
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
| 398 |
+
class HipblasltGemmOp : public Callable<ParamsT> {
|
| 399 |
+
public:
|
| 400 |
+
HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
|
| 401 |
+
|
| 402 |
+
TuningStatus Call(const ParamsT* params) override {
|
| 403 |
+
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
| 404 |
+
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
| 405 |
+
auto a_datatype = HipDataTypeFor<AT>();
|
| 406 |
+
auto b_datatype = HipDataTypeFor<BT>();
|
| 407 |
+
auto in_out_datatype = HipDataTypeFor<CT>();
|
| 408 |
+
auto opa = _hipblasOpFromChar(params->transa);
|
| 409 |
+
auto opb = _hipblasOpFromChar(params->transb);
|
| 410 |
+
|
| 411 |
+
TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
|
| 412 |
+
|
| 413 |
+
float alpha = GetAlphaFromParams<CT>(params);
|
| 414 |
+
float beta = GetBetaFromParams<CT>(params);
|
| 415 |
+
|
| 416 |
+
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
|
| 417 |
+
if (opa == HIPBLAS_OP_N) {
|
| 418 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
|
| 419 |
+
}
|
| 420 |
+
else {
|
| 421 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
|
| 422 |
+
}
|
| 423 |
+
if (opb == HIPBLAS_OP_N) {
|
| 424 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
|
| 425 |
+
}
|
| 426 |
+
else {
|
| 427 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
|
| 428 |
+
}
|
| 429 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
|
| 430 |
+
|
| 431 |
+
// specific to batched gemmm
|
| 432 |
+
int batch = GetBatchFromParams<CT>(params);
|
| 433 |
+
if (batch > 1) {
|
| 434 |
+
int64_t stride_a = GetStrideAFromParams<CT>(params);
|
| 435 |
+
int64_t stride_b = GetStrideBFromParams<CT>(params);
|
| 436 |
+
int64_t stride_c = GetStrideCFromParams<CT>(params);
|
| 437 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 438 |
+
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 439 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 440 |
+
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
|
| 441 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 442 |
+
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 443 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 444 |
+
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
|
| 445 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 446 |
+
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
| 447 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
| 448 |
+
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
|
| 452 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
|
| 453 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
|
| 454 |
+
|
| 455 |
+
// specific to scaled gemm
|
| 456 |
+
const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
|
| 457 |
+
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
|
| 458 |
+
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
|
| 459 |
+
if (mat1_scale_ptr && mat2_scale_ptr) {
|
| 460 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
|
| 461 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
|
| 462 |
+
}
|
| 463 |
+
if (result_scale_ptr) {
|
| 464 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
|
| 468 |
+
auto bias_datatype = GetBiasTypeFromParams<CT>(params);
|
| 469 |
+
if (bias_ptr) {
|
| 470 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
|
| 471 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
|
| 472 |
+
auto activation = GetActivationFromParams<CT>(params);
|
| 473 |
+
if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
|
| 474 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
|
| 475 |
+
}
|
| 476 |
+
else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
|
| 477 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
|
| 478 |
+
}
|
| 479 |
+
else {
|
| 480 |
+
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
|
| 481 |
+
}
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
size_t workspace_size = GetHipblasltWorkspaceSize();
|
| 485 |
+
|
| 486 |
+
auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
|
| 487 |
+
|
| 488 |
+
size_t ret_workspace_size = 0;
|
| 489 |
+
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
|
| 490 |
+
matmul.descriptor(),
|
| 491 |
+
&alpha,
|
| 492 |
+
mat_a,
|
| 493 |
+
mat_b,
|
| 494 |
+
&beta,
|
| 495 |
+
mat_c,
|
| 496 |
+
mat_c,
|
| 497 |
+
algo_,
|
| 498 |
+
ret_workspace_size);
|
| 499 |
+
|
| 500 |
+
if (status == HIPBLAS_STATUS_SUCCESS) {
|
| 501 |
+
if (ret_workspace_size >= workspace_size) {
|
| 502 |
+
return FAIL;
|
| 503 |
+
}
|
| 504 |
+
}
|
| 505 |
+
else {
|
| 506 |
+
return FAIL;
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
void* workspace_buffer = nullptr;
|
| 510 |
+
if (workspace_size > 0) {
|
| 511 |
+
workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
|
| 515 |
+
matmul.descriptor(),
|
| 516 |
+
&alpha,
|
| 517 |
+
params->a,
|
| 518 |
+
mat_a,
|
| 519 |
+
params->b,
|
| 520 |
+
mat_b,
|
| 521 |
+
&beta,
|
| 522 |
+
params->c,
|
| 523 |
+
mat_c,
|
| 524 |
+
params->c,
|
| 525 |
+
mat_c,
|
| 526 |
+
&algo_,
|
| 527 |
+
workspace_buffer,
|
| 528 |
+
workspace_size,
|
| 529 |
+
at::cuda::getCurrentCUDAStream()));
|
| 530 |
+
|
| 531 |
+
//TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
|
| 532 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
|
| 533 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
|
| 534 |
+
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
|
| 535 |
+
if (workspace_size > 0) {
|
| 536 |
+
c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
|
| 537 |
+
}
|
| 538 |
+
return OK;
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
private:
|
| 542 |
+
hipblasLtMatmulAlgo_t algo_;
|
| 543 |
+
};
|
| 544 |
+
|
| 545 |
+
// Enumerates every hipBLASLt algorithm applicable to the given type/layout
// combination and wraps each in a tunable Callable. Returns a vector of
// (unique op name, op) pairs sorted by algorithm index for determinism.
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
auto GetHipBlasLtTypeStringAndOps() {
  hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
  hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
  auto a_datatype = HipDataTypeFor<AT>();
  auto b_datatype = HipDataTypeFor<BT>();
  auto in_out_datatype = HipDataTypeFor<CT>();
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

  // A temporary handle is created solely for algorithm enumeration and
  // destroyed immediately afterwards.
  hipblasLtHandle_t handle;
  TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
  TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
                                                   hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
                                                   transa_outer,
                                                   transb_outer,
                                                   a_datatype,
                                                   b_datatype,
                                                   in_out_datatype,
                                                   in_out_datatype,
                                                   HIPBLAS_COMPUTE_32F,
                                                   heuristic_result));
  TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));

  // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
  std::sort(heuristic_result.begin(),
            heuristic_result.end(),
            [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
              return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
            });

  int returned_algo_count = heuristic_result.size();
  std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
  for (int i = 0; i < returned_algo_count; i++) {
    auto algo = heuristic_result[i].algo;
    int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
    auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
    // Name encodes the layouts plus the hipBLASLt algorithm index.
    std::string type_string = c10::str(
        "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
    ret.emplace_back(type_string, std::move(callable));
  }

  return ret;
}
|
| 588 |
+
|
| 589 |
+
// Convenience wrapper: hipBLASLt candidate ops for a plain GEMM where
// A, B, and C all share element type T.
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
}
|
| 593 |
+
|
| 594 |
+
// Convenience wrapper: hipBLASLt candidate ops for GEMM + bias epilogue
// with a single element type T.
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
}
|
| 598 |
+
|
| 599 |
+
// Convenience wrapper: hipBLASLt candidate ops for strided-batched GEMM
// with a single element type T.
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}
|
| 603 |
+
|
| 604 |
+
// Convenience wrapper: hipBLASLt candidate ops for scaled GEMM, where
// A/B/C element types may differ (e.g. fp8 inputs).
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtScaledGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
}
|
| 608 |
+
|
| 609 |
+
#undef TORCH_HIPBLASLT_CHECK
|
| 610 |
+
|
| 611 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
+
// Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 7 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 8 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 9 |
+
#include <c10/util/StringUtil.h>
|
| 10 |
+
|
| 11 |
+
#define ROCBLAS_BETA_FEATURES_API
|
| 12 |
+
#include <rocblas/rocblas.h>
|
| 13 |
+
|
| 14 |
+
#define TORCH_ROCBLAS_CHECK(EXPR) \
|
| 15 |
+
do { \
|
| 16 |
+
rocblas_status __err = EXPR; \
|
| 17 |
+
TORCH_CHECK(__err == rocblas_status_success, \
|
| 18 |
+
"rocblas error: ", \
|
| 19 |
+
rocblas_status_to_string(__err), \
|
| 20 |
+
" when calling `" #EXPR "`"); \
|
| 21 |
+
} while (0)
|
| 22 |
+
|
| 23 |
+
namespace at::cuda::tunable {
|
| 24 |
+
|
| 25 |
+
// Maps a C++ element type to the corresponding rocBLAS datatype enum.
// Only the specializations below are defined; other types fail to link.
template <typename T>
constexpr rocblas_datatype RocBlasDataTypeFor();

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
  return rocblas_datatype_f16_r;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
  return rocblas_datatype_bf16_r;
}

// Complex types map to the _c (complex) variants.
template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}
|
| 57 |
+
|
| 58 |
+
// Maps a C++ element type to the rocBLAS *compute* type used for GEMM.
// Differs from RocBlasDataTypeFor for 16-bit floats (see below).
template <typename T>
constexpr rocblas_datatype RocBlasComputeTypeFor();

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
  return rocblas_datatype_f64_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // FP16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
  // Note that we're returning the _compute_ type for a given datatype.
  // As of 12/2022, using compute type FP16 for 16-bit floats was much
  // slower than using compute type FP32. So we use FP32 compute even for
  // BF16 datatypes. This is how GEMM is implemented even in the function
  // rocblasGemmHelper (see fpgeneric.h)
  return rocblas_datatype_f32_r;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
  return rocblas_datatype_f32_c;
}

template <>
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
  return rocblas_datatype_f64_c;
}
|
| 100 |
+
|
| 101 |
+
// Converts alpha/beta scalars to match the rocBLAS compute type.
// Full-precision types pass through unchanged; Half/BFloat16 are widened
// to float because their compute type is FP32 (see RocBlasComputeTypeFor).
template <typename T>
auto DoCastForHalfOrBfloat16(const T fp) {
  return fp;
}

template <>
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
  // alpha and beta should be the same as compute_type, in Half case it is float.
  float h = fp;
  return h;
}

template <>
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
  // alpha and beta should be the same as compute_type, in bfloat16 case it is float.
  float h = fp;
  return h;
}
|
| 119 |
+
|
| 120 |
+
// Translates a BLAS transpose character ('n'/'t'/'c', either case) into
// the rocblas_operation enum. Any other character is a hard error.
static rocblas_operation _rocblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return rocblas_operation_none;
    case 't':
    case 'T':
      return rocblas_operation_transpose;
    case 'c':
    case 'C':
      return rocblas_operation_conjugate_transpose;
  }
  AT_ERROR(
      "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}
|
| 135 |
+
|
| 136 |
+
// Tunable Callable wrapping a single rocBLAS GEMM solution index.
// Call() runs rocblas_gemm_ex with that solution and reports FAIL if
// rocBLAS rejects it for the given params.
template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
  public:
    RocblasGemmOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      auto compute_type = RocBlasComputeTypeFor<T>();
      // alpha/beta must be widened to the compute type for 16-bit floats.
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda,
          params->b, input_output_type, params->ldb,
          &h_b,
          // C is passed as both the C and D (output) matrices: in-place update.
          params->c, input_output_type, params->ldc,
          params->c, input_output_type, params->ldc,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;  // rocBLAS solution index chosen at enumeration time
};
|
| 170 |
+
|
| 171 |
+
template <typename T>
|
| 172 |
+
auto GetRocBlasGemmTypeStringAndOps() {
|
| 173 |
+
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
| 174 |
+
int solution_size;
|
| 175 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 176 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 177 |
+
// Get the number of available solutions
|
| 178 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 179 |
+
input_output_type,
|
| 180 |
+
input_output_type,
|
| 181 |
+
compute_type,
|
| 182 |
+
rocblas_gemm_flags_none,
|
| 183 |
+
nullptr,
|
| 184 |
+
&solution_size));
|
| 185 |
+
std::vector<int> solutions(solution_size);
|
| 186 |
+
// Get the list of available solutions
|
| 187 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 188 |
+
input_output_type,
|
| 189 |
+
input_output_type,
|
| 190 |
+
compute_type,
|
| 191 |
+
rocblas_gemm_flags_none,
|
| 192 |
+
solutions.data(),
|
| 193 |
+
&solution_size));
|
| 194 |
+
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
| 195 |
+
std::sort(solutions.begin(), solutions.end());
|
| 196 |
+
|
| 197 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
|
| 198 |
+
for (size_t i = 0; i < solutions.size(); ++i) {
|
| 199 |
+
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
|
| 200 |
+
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
| 201 |
+
}
|
| 202 |
+
return ret;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// Tunable Callable wrapping a single rocBLAS strided-batched GEMM solution.
// Mirrors RocblasGemmOp but forwards the batch count and per-matrix strides.
template <typename T>
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}

    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      auto input_output_type = RocBlasDataTypeFor<T>();
      auto compute_type = RocBlasComputeTypeFor<T>();
      // alpha/beta must be widened to the compute type for 16-bit floats.
      auto h_a = DoCastForHalfOrBfloat16(params->alpha);
      auto h_b = DoCastForHalfOrBfloat16(params->beta);
      auto status = rocblas_gemm_strided_batched_ex(
          (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
          _rocblasOpFromChar(params->transa),
          _rocblasOpFromChar(params->transb),
          params->m, params->n, params->k,
          &h_a,
          params->a, input_output_type, params->lda, params->stride_a,
          params->b, input_output_type, params->ldb, params->stride_b,
          &h_b,
          // C doubles as the output matrix D: in-place update.
          params->c, input_output_type, params->ldc, params->stride_c,
          params->c, input_output_type, params->ldc, params->stride_c,
          params->batch,
          compute_type,
          rocblas_gemm_algo_solution_index,
          solution_,
          rocblas_gemm_flags_none);
      if (status != rocblas_status_success) {
        return FAIL;
      }
      return OK;
    }

  private:
    int solution_;  // rocBLAS solution index chosen at enumeration time
};
|
| 240 |
+
|
| 241 |
+
template <typename T>
|
| 242 |
+
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
|
| 243 |
+
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
| 244 |
+
int solution_size;
|
| 245 |
+
auto input_output_type = RocBlasDataTypeFor<T>();
|
| 246 |
+
auto compute_type = RocBlasComputeTypeFor<T>();
|
| 247 |
+
// Get the number of available solutions
|
| 248 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 249 |
+
input_output_type,
|
| 250 |
+
input_output_type,
|
| 251 |
+
compute_type,
|
| 252 |
+
rocblas_gemm_flags_none,
|
| 253 |
+
nullptr,
|
| 254 |
+
&solution_size));
|
| 255 |
+
std::vector<int> solutions(solution_size);
|
| 256 |
+
// Get the list of available solutions
|
| 257 |
+
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
| 258 |
+
input_output_type,
|
| 259 |
+
input_output_type,
|
| 260 |
+
compute_type,
|
| 261 |
+
rocblas_gemm_flags_none,
|
| 262 |
+
solutions.data(),
|
| 263 |
+
&solution_size));
|
| 264 |
+
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
| 265 |
+
std::sort(solutions.begin(), solutions.end());
|
| 266 |
+
|
| 267 |
+
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
|
| 268 |
+
for (size_t i = 0; i < solutions.size(); ++i) {
|
| 269 |
+
auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
|
| 270 |
+
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
| 271 |
+
}
|
| 272 |
+
return ret;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <cuda_runtime.h>
|
| 13 |
+
|
| 14 |
+
#include <ATen/cuda/tunable/Tunable.h>
|
| 15 |
+
|
| 16 |
+
namespace at::cuda::tunable {
|
| 17 |
+
|
| 18 |
+
class StreamTimer : public ITimer {
|
| 19 |
+
public:
|
| 20 |
+
StreamTimer();
|
| 21 |
+
virtual ~StreamTimer() override;
|
| 22 |
+
|
| 23 |
+
void Start() override;
|
| 24 |
+
|
| 25 |
+
void End() override;
|
| 26 |
+
|
| 27 |
+
float Duration() override;
|
| 28 |
+
|
| 29 |
+
private:
|
| 30 |
+
cudaEvent_t start_;
|
| 31 |
+
cudaEvent_t end_;
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <c10/util/CallOnce.h>
|
| 13 |
+
|
| 14 |
+
#include <fstream>
|
| 15 |
+
#include <functional>
|
| 16 |
+
#include <iostream>
|
| 17 |
+
#include <memory>
|
| 18 |
+
#include <mutex>
|
| 19 |
+
#include <string>
|
| 20 |
+
#include <type_traits>
|
| 21 |
+
#include <unordered_map>
|
| 22 |
+
#include <utility>
|
| 23 |
+
#include <vector>
|
| 24 |
+
|
| 25 |
+
namespace at::cuda::tunable {
|
| 26 |
+
|
| 27 |
+
namespace detail {
|
| 28 |
+
|
| 29 |
+
struct MaybeDelete {
|
| 30 |
+
bool owns_pointer;
|
| 31 |
+
void operator()(std::ostream* os) const { if (owns_pointer) delete os; }
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
using OstreamPtr = std::unique_ptr<std::ostream, MaybeDelete>;
|
| 35 |
+
|
| 36 |
+
// Resolves a log destination name to a stream: "out" -> std::cout,
// "err" -> std::cerr, anything else is opened (and owned) as a file.
// Takes the name by const reference (the old by-value parameter copied
// the string) and uses operator== instead of .compare() for clarity.
static OstreamPtr get_stream(const std::string& filename) {
  if (filename == "out") {
    return OstreamPtr { &std::cout, MaybeDelete {false} };
  }
  if (filename == "err") {
    return OstreamPtr { &std::cerr, MaybeDelete {false} };
  }
  return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} };
}
|
| 47 |
+
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// Writes `msg` to the configured log stream when the user-requested
// verbosity (PYTORCH_TUNABLEOP_VERBOSE, default 0) is at least `level`.
// Destination defaults to stderr and can be redirected with
// PYTORCH_TUNABLEOP_VERBOSE_FILENAME ("out", "err", or a file path).
// Environment is read once; the stream stays open for process lifetime.
static void TunableLog(int level, const std::string& msg) {
  static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME");
  static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE");
  static int level_user = env_verbose ? atoi(env_verbose) : 0;
  static auto streamptr = detail::get_stream(env_file ? env_file : "err");
  if (level_user >= level) {
    (*streamptr) << msg <<std::endl;
  }
}
|
| 59 |
+
#define TUNABLE_LOGV(LEVEL, ...) TunableLog(LEVEL, c10::str(__VA_ARGS__))
|
| 60 |
+
#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
|
| 61 |
+
#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
|
| 62 |
+
#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)
|
| 63 |
+
|
| 64 |
+
// Outcome reported by a tunable Callable invocation.
enum TORCH_CUDA_CPP_API TuningStatus {
  OK = 0,           // candidate ran successfully
  FAIL = 1,         // candidate failed or was rejected for these params
  UNSUPPORTED = 2,  // operation not supported at all
};
|
| 69 |
+
|
| 70 |
+
// Mapping from params signature to kernel id
|
| 71 |
+
class TORCH_CUDA_CPP_API ResultEntry {
|
| 72 |
+
public:
|
| 73 |
+
explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
|
| 74 |
+
bool operator==(const ResultEntry& other) { return key_ == other.key_; }
|
| 75 |
+
bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
|
| 76 |
+
operator std::string () { return key_; }
|
| 77 |
+
std::string GetKey() const { return key_; }
|
| 78 |
+
double GetTime() const { return time_; }
|
| 79 |
+
friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
|
| 80 |
+
static ResultEntry Null() { return ResultEntry("Null", 0.0); }
|
| 81 |
+
static ResultEntry Default() { return ResultEntry("Default", 0.0); }
|
| 82 |
+
|
| 83 |
+
private:
|
| 84 |
+
std::string key_;
|
| 85 |
+
double time_;
|
| 86 |
+
};
|
| 87 |
+
|
| 88 |
+
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
|
| 89 |
+
typedef std::unordered_map<std::string, KernelMap> ResultsMap;
|
| 90 |
+
|
| 91 |
+
// Serializable bundle of tuning state: validator key/value pairs plus
// the full results map.
struct TORCH_CUDA_CPP_API TuningResults {
  // Validates if these results are compatible with the libraries
  std::unordered_map<std::string, std::string> validators;

  // Mapping from Callable signature to Callable's tuning result
  ResultsMap results;
};
|
| 98 |
+
|
| 99 |
+
// Store of tuning results keyed by op signature then params signature.
// Declarations only; implementations live in the corresponding .cpp.
// A mutex member suggests thread-safe access — confirm in the .cpp.
class TORCH_CUDA_CPP_API TuningResultsManager {
  public:
    TuningResultsManager() = default;
    ~TuningResultsManager() = default;

    // All recorded results for one op signature.
    KernelMap Lookup(const std::string& op_signature);

    // Best entry recorded for the (op, params) pair.
    ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);

    // Insertion helper; presumably called with lock_ already held — confirm in the .cpp.
    inline void AddImpl(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best,
        KernelMap& kernel_map);

    // Records `best` for the (op, params) pair.
    void Add(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best);

    // Removes the entry for the (op, params) pair.
    void Delete(const std::string& op_signature, const std::string& params_signature);

    // Merge helper writing into `results`; presumably called under lock_ — confirm in the .cpp.
    inline void DisjointMergeImpl(
        const std::string& op_signature,
        const KernelMap& kernel_map,
        /*out*/ ResultsMap& results);

    // Loads previously serialized results into the store.
    void Load(const ResultsMap& results_to_load);

    // Snapshot of all stored results.
    ResultsMap Dump();

    // Merges kernel_map for one op signature into the store.
    void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);

    // Entry count — exact semantics (ops vs. params entries) defined in the .cpp.
    size_t GetSize();

  private:
    std::mutex lock_;     // guards results_
    ResultsMap results_;
};
|
| 136 |
+
|
| 137 |
+
// Registry of named (get, validate) callback pairs used to check that
// loaded tuning results are compatible with the current build
// (e.g. matching PyTorch version). Declarations only; see the .cpp.
class TORCH_CUDA_CPP_API TuningResultsValidator {
  public:
    // Produces the current value for a validator key.
    using GetFunc = std::function<std::string()>;
    // Checks a loaded value against the current environment.
    using ValidateFunc = std::function<TuningStatus(const std::string&)>;
    using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;

    TuningResultsValidator();
    ~TuningResultsValidator() = default;

    // Current value from every registered GetFunc, keyed by validator name.
    std::unordered_map<std::string, std::string> GetAllValidators() const;
    // Runs the registered ValidateFuncs against the supplied values.
    TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
    void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);

  protected:
    std::string GetPyTorchVersion() const;
    TuningStatus ValidatePyTorchVersion(const std::string& value) const;

  public:
    // Validator keys that must always be present.
    static constexpr const std::array mandatory_keys{"PT_VERSION"};

  private:
    GetValidateFuncs validators_;
};
|
| 160 |
+
|
| 161 |
+
// Process-wide configuration and state for TunableOp (obtained via
// getTuningContext()). Non-copyable and non-movable singleton-style object
// holding the results manager, validator, and all tuning knobs.
class TORCH_CUDA_CPP_API TuningContext {
  public:
    TuningContext();
    ~TuningContext();
    TuningContext(TuningContext &) = delete;
    TuningContext(TuningContext &&) = delete;
    TuningContext &operator=(TuningContext &) = delete;
    TuningContext &operator=(TuningContext &&) = delete;

    // Master switch: when disabled, ops fall back to their default implementation.
    void EnableTunableOp(bool value);
    bool IsTunableOpEnabled() const;

    // Whether new tuning (benchmarking of candidates) may run; with this off,
    // only previously recorded results are used.
    void EnableTuning(bool value);
    bool IsTuningEnabled() const;

    // Whether candidate outputs are numerically checked against the default
    // implementation during tuning.
    void EnableNumericsCheck(bool value);
    bool IsNumericsCheckEnabled() const;

    // Budget for the timed tuning loop, per candidate.
    void SetMaxTuningDurationMs(int max_duration_ms);
    int GetMaxTuningDurationMs() const;

    void SetMaxTuningIterations(int max_iter);
    int GetMaxTuningIterations() const;

    // Budget for the warmup loop that precedes timing, per candidate.
    void SetMaxWarmupDurationMs(int max_duration_ms);
    int GetMaxWarmupDurationMs() const;

    void SetMaxWarmupIterations(int max_iter);
    int GetMaxWarmupIterations() const;

    // Whether the instruction cache is flushed between timed iterations.
    void EnableICacheFlush(bool value);
    bool IsICacheFlushEnabled() const;

    // Size of the rotating parameter buffer used during tuning (0 disables rotation).
    void SetRotatingBufferSize(int size);
    int GetRotatingBufferSize() const;

    TuningResultsManager& GetTuningResultsManager();

    TuningResultsValidator& GetTuningResultsValidator();

    // Snapshot of current validators + results, suitable for serialization.
    TuningResults GetTuningResults();

    // Validates and loads previously saved results.
    TuningStatus LoadTuningResults(const TuningResults& tr);

    // File the results are read from / written to. insert_device_ordinal
    // optionally makes the filename device-specific.
    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
    std::string GetFilename() const;

    // Whether results are flushed to the file when the context is destroyed.
    void WriteFileOnExit(bool value);

    // Read/write results from/to `filename` (or the configured filename when
    // the argument is empty). Return success.
    bool ReadFile(const std::string& filename={});
    bool WriteFile(const std::string& filename={});

  private:
    bool enable_;
    bool tuning_enable_;
    bool manager_initialized_;
    bool write_file_on_exit_;
    bool numerics_check_enable_;
    int max_tuning_duration_ms_;
    int max_tuning_iterations_;
    int max_warmup_duration_ms_;
    int max_warmup_iterations_;
    bool icache_flush_;
    int rotating_buffer_size_;
    mutable TuningResultsManager manager_;
    mutable c10::once_flag manager_init_once_;  // one-time lazy init of manager_
    TuningResultsValidator validator_;
    std::string filename_;
    size_t results_count_from_input_file_;
};
|
| 231 |
+
|
| 232 |
+
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
|
| 233 |
+
|
| 234 |
+
// Abstract timer used by TunableOp to measure candidate execution time
// (see the TimerT template parameter of TunableOp).
class ITimer {
  public:
    ITimer() = default;
    virtual ~ITimer() = default;

    // Marks the beginning of the timed region.
    virtual void Start() = 0;
    // Marks the end of the timed region.
    virtual void End() = 0;

    /// Computes the elapsed time in milliseconds between Start() and End()
    virtual float Duration() = 0;
};
|
| 245 |
+
|
| 246 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <ATen/cuda/tunable/GemmCommon.h>
|
| 13 |
+
#ifdef USE_ROCM
|
| 14 |
+
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
| 15 |
+
#include <ATen/cuda/tunable/GemmRocblas.h>
|
| 16 |
+
#endif
|
| 17 |
+
#include <ATen/cuda/tunable/StreamTimer.h>
|
| 18 |
+
#include <ATen/cuda/tunable/TunableOp.h>
|
| 19 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 20 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 21 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 22 |
+
#include <c10/util/Float8_e5m2.h>
|
| 23 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 24 |
+
#include <c10/util/StringUtil.h>
|
| 25 |
+
|
| 26 |
+
namespace at::cuda::tunable {
|
| 27 |
+
|
| 28 |
+
// Baseline GEMM candidate backed by at::cuda::blas::gemm_internal.
// Registered under the name "Default" so tuning always has a fallback.
template <typename T>
class DefaultGemmOp : public Callable<GemmParams<T>> {
  public:
    // Forwards the packed GEMM parameters to the vendor blas call.
    TuningStatus Call(const GemmParams<T>* params) override {
      at::cuda::blas::gemm_internal<T>(
          params->transa, params->transb,
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda,
          params->b, params->ldb,
          params->beta,
          params->c, params->ldc);
      return OK;
    }
};
|
| 43 |
+
|
| 44 |
+
// Maps a BLAS transpose character to a bool: 't'/'T' mean "transposed",
// anything else (e.g. 'n'/'N') means "not transposed".
static bool _transposeBoolFromChar(char op) {
  switch (op) {
    case 't':
    case 'T':
      return true;
    default:
      return false;
  }
}
|
| 47 |
+
|
| 48 |
+
// Baseline GEMM+bias candidate backed by at::cuda::blas::gemm_and_bias.
// Converts the char transpose flags to the bools that API expects.
template <typename T>
class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
  public:
    TuningStatus Call(const GemmAndBiasParams<T>* params) override {
      at::cuda::blas::gemm_and_bias<T>(
          _transposeBoolFromChar(params->transa),
          _transposeBoolFromChar(params->transb),
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda,
          params->b, params->ldb,
          params->bias,
          params->c, params->ldc,
          params->activation);
      return OK;
    }
};
|
| 65 |
+
|
| 66 |
+
// Baseline strided-batched GEMM candidate backed by
// at::cuda::blas::bgemm_internal.
template <typename T>
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
  public:
    TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
      at::cuda::blas::bgemm_internal<T>(
          params->transa, params->transb,
          params->m, params->n, params->k,
          params->alpha,
          params->a, params->lda, params->stride_a,
          params->b, params->ldb, params->stride_b,
          params->beta,
          params->c, params->ldc, params->stride_c,
          params->batch);
      return OK;
    }
};
|
| 82 |
+
|
| 83 |
+
// Baseline scaled (FP8-style) GEMM candidate backed by
// at::cuda::blas::scaled_gemm. Forwards per-operand dtypes and scale
// pointers unchanged.
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
  public:
    TuningStatus Call(const ScaledGemmParams<T>* params) override {
      at::cuda::blas::scaled_gemm(
          params->transa,
          params->transb,
          params->m,
          params->n,
          params->k,
          params->a,
          params->a_scale_ptr,
          params->lda,
          params->a_dtype,
          params->b,
          params->b_scale_ptr,
          params->ldb,
          params->b_dtype,
          params->bias_ptr,
          params->bias_dtype,
          params->c,
          params->c_scale_ptr,
          params->ldc,
          params->c_dtype,
          params->amax_ptr,
          params->use_fast_accum);
      return OK;
    }
};
|
| 112 |
+
|
| 113 |
+
// IsZero: type-dispatched test for a zero scalar value. Specialized for
// types that cannot (or should not) be compared directly against a float
// literal.
template <typename T>
inline bool IsZero(T v) {
  return v == 0.0f;
}

template <>
inline bool IsZero(BFloat16 v) {
  // Compare the raw bit pattern; 0x0000 is positive zero.
  // NOTE(review): negative zero (0x8000) is treated as non-zero here —
  // confirm that is intended.
  return v.x == 0;
}

template <>
inline bool IsZero(Half v) {
  // Convert to float first; Half has no direct comparison against 0.0f here.
  return float(v) == 0.0f;
}

template <>
inline bool IsZero(c10::complex<double> v) {
  return v == 0.0;
}

template <>
inline bool IsZero(c10::complex<float> v) {
  return v == 0.0f;
}
|
| 137 |
+
|
| 138 |
+
// TypeName: maps a scalar type to a human-readable name, used to build
// TunableOp signatures (see GemmTunableOp::Signature below). The dummy
// value parameter exists only to drive template argument deduction.
template <typename T>
inline std::string TypeName(T v) {
  return "unknown";
}

template <>
inline std::string TypeName(float v) {
  return "float";
}

template <>
inline std::string TypeName(double v) {
  return "double";
}

template <>
inline std::string TypeName(BFloat16 v) {
  return "BFloat16";
}

template <>
inline std::string TypeName(Half v) {
  return "Half";
}

template <>
inline std::string TypeName(Float8_e4m3fn v) {
  return "Float8_e4m3fn";
}

template <>
inline std::string TypeName(Float8_e5m2 v) {
  return "Float8_e5m2";
}

template <>
inline std::string TypeName(Float8_e4m3fnuz v) {
  return "Float8_e4m3fnuz";
}

template <>
inline std::string TypeName(Float8_e5m2fnuz v) {
  return "Float8_e5m2fnuz";
}

template <>
inline std::string TypeName(c10::complex<double> v) {
  return "c10::complex<double>";
}

template <>
inline std::string TypeName(c10::complex<float> v) {
  return "c10::complex<float>";
}
|
| 192 |
+
|
| 193 |
+
// Tunable GEMM: always registers the "Default" implementation and, on ROCm,
// adds rocBLAS and hipBLASLt candidates. TunableOp then picks the fastest
// registered candidate for a given problem.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
  public:
    GemmTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());

#ifdef USE_ROCM
      // rocBLAS candidates are enabled when the env var is unset or "1";
      // any other value disables them.
      static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
      if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
        for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }

      // Same opt-out scheme for hipBLASLt.
      static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }
      }
#endif
    }

    // Operator signature used as the first-level key in the results store,
    // e.g. "GemmTunableOp_float_TN".
    std::string Signature() override {
      return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
|
| 225 |
+
|
| 226 |
+
// Tunable GEMM+bias: "Default" implementation plus hipBLASLt candidates on
// ROCm (no rocBLAS variant exists for this fused op).
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
  public:
    GemmAndBiasTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());

#ifdef USE_ROCM
      // hipBLASLt candidates are enabled when the env var is unset or "1".
      static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }
      }
#endif
    }

    std::string Signature() override {
      return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
|
| 251 |
+
|
| 252 |
+
// Tunable strided-batched GEMM: "Default" implementation plus rocBLAS and
// hipBLASLt candidates on ROCm.
template <typename T, BlasOp ALayout, BlasOp BLayout>
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
  public:
    GemmStridedBatchedTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());

#ifdef USE_ROCM
      // rocBLAS candidates are enabled when the env var is unset or "1".
      static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
      if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
        for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
          this->RegisterOp(std::move(name), std::move(op));
        }
      }

      // Same opt-out scheme for hipBLASLt.
      static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
      if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
        // disallow tuning of hipblaslt with c10::complex
        if constexpr (
            !std::is_same_v<T, c10::complex<float>> &&
            !std::is_same_v<T, c10::complex<double>>) {
          for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
            this->RegisterOp(std::move(name), std::move(op));
          }
        }
      }
#endif
    }

    std::string Signature() override {
      return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
|
| 284 |
+
|
| 285 |
+
// Tunable scaled GEMM (separate A/B/C element types): "Default"
// implementation plus hipBLASLt candidates on ROCm. Unlike the other
// tunable GEMMs, the hipBLASLt path here has no env-var opt-out.
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
  public:
    ScaledGemmTunableOp() {
      this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());

#ifdef USE_ROCM
      for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
        this->RegisterOp(std::move(name), std::move(op));
      }
#endif
    }

    // Signature encodes all three element types plus the layouts.
    std::string Signature() override {
      return c10::str("ScaledGemmTunableOp",
          "_", TypeName<AT>(AT{}),
          "_", TypeName<BT>(BT{}),
          "_", TypeName<CT>(CT{}),
          "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
    }
};
|
| 306 |
+
|
| 307 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Original TunableOp is from onnxruntime.
|
| 2 |
+
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
| 3 |
+
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
| 4 |
+
// Copyright (c) Microsoft Corporation.
|
| 5 |
+
// Licensed under the MIT license.
|
| 6 |
+
//
|
| 7 |
+
// Adapting TunableOp into PyTorch
|
| 8 |
+
// Copyright (c) Advanced Micro Devices, Inc.
|
| 9 |
+
//
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <ATen/cuda/tunable/Tunable.h>
|
| 13 |
+
#include <ATen/cuda/Sleep.h>
|
| 14 |
+
#include <c10/cuda/CUDACachingAllocator.h>
|
| 15 |
+
|
| 16 |
+
#ifndef _WIN32
|
| 17 |
+
#include <cxxabi.h>
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <string>
|
| 21 |
+
#include <type_traits>
|
| 22 |
+
#include <unordered_map>
|
| 23 |
+
#include <vector>
|
| 24 |
+
|
| 25 |
+
namespace at::cuda::tunable {
|
| 26 |
+
|
| 27 |
+
// Interface for a single tunable implementation candidate. A TunableOp
// holds a set of these and selects the fastest one.
template <typename ParamsT>
class Callable {
  public:
    Callable() = default;
    Callable(Callable&&) = default;
    virtual ~Callable() = default;
    // Executes the candidate on the given params. The base implementation
    // always fails; subclasses override with a real kernel launch.
    virtual TuningStatus Call(const ParamsT*) {
      return FAIL;
    }
    // By default a candidate is considered supported iff a trial Call
    // succeeds; subclasses may implement a cheaper check.
    virtual TuningStatus IsSupported(const ParamsT* params) {
      return Call(params);
    }
};
|
| 40 |
+
|
| 41 |
+
// Core tuning driver. Holds a named set of Callable candidates; on each
// invocation it either looks up (or benchmarks and records) the fastest
// candidate for the params' signature and dispatches to it.
// ParamsT must provide Signature(), DeepCopy(), Delete(), GetSize(), and
// NumericalCheck(); TimerT must provide Start()/End()/Duration() (see ITimer).
template <typename ParamsT, typename TimerT>
class TunableOp {
  public:
    TunableOp() = default;
    TunableOp(TunableOp&&) = default;
    virtual ~TunableOp() = default;

    // Entry point: resolve the best candidate for `params` and run it.
    TuningStatus operator()(const ParamsT* params) {
      ResultEntry result = ResultEntry::Null();
      TuningContext* ctx = getTuningContext();
      if (ctx->IsTunableOpEnabled()) {
        auto& mgr = ctx->GetTuningResultsManager();
        auto op_sig = Signature();
        auto params_sig = params->Signature();
        result = mgr.Lookup(op_sig, params_sig);
        // If no previous tuning result was found, do the tuning iff tuning is enabled
        if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
          result = FindFastest(params);
          mgr.Add(op_sig, params_sig, result);
        }
      }
      else {
        // TunableOp disabled entirely: always use the default implementation.
        result = ResultEntry::Default();
      }
      if (result == ResultEntry::Null()) {
        // Lookup missed and tuning is disabled: fall back to default.
        TUNABLE_LOG2("no result, using default");
        result = ResultEntry::Default();
      }
      auto iter = ops_.find(result);
      TORCH_CHECK(iter != ops_.end());
      return iter->second->Call(params);
    }

    // Operator signature; lazily derived from the dynamic type name.
    virtual std::string Signature() {
      // According to C++17 standard https://wg21.link/n4659 section 15.7.4
      // > if the operand of typeid refers to the
      // > object under construction or destruction, typeid yields the std::type_info object representing the constructor
      // > or destructor’s class.
      // So delay the op signature generation.
      c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
      return signature_;
    }

  protected:
    // Adds a named candidate; name order is preserved in op_names_ for
    // deterministic iteration during tuning.
    void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
      this->op_names_.emplace_back(name);
      this->ops_.emplace(name, std::move(op));
    }

  private:
    // Runs `op` num_iter times without timing, cycling through the rotating
    // parameter copies via `offset`.
    static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
      TuningContext* ctx = getTuningContext();
      bool do_flush = ctx->IsICacheFlushEnabled();
      for (size_t i = 0; i < num_iter; i++) {
        if (do_flush) {
          at::cuda::flush_icache();
        }
        TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
      }
    }

    // Times num_iter calls of `op` and returns the mean duration per
    // iteration (units are whatever TimerT::Duration returns).
    static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
      TuningContext* ctx = getTuningContext();
      bool do_flush = ctx->IsICacheFlushEnabled();
      TimerT timer{};
      timer.Start();
      for (size_t i = 0; i < num_iter; i++) {
        if (do_flush) {
          at::cuda::flush_icache();
        }
        TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
      }
      timer.End();
      return timer.Duration() / num_iter;
    }

  protected:
    // Benchmarks every registered candidate for `params` and returns the
    // name + time of the fastest one. Candidates that fail to run, fail the
    // numerics check, or are clearly slower than the current best are skipped.
    virtual ResultEntry FindFastest(const ParamsT* params) {
      TuningContext* ctx = getTuningContext();
      auto op_sig = Signature();
      auto params_sig = params->Signature();
      TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
      auto min_duration_ms = std::numeric_limits<double>::infinity();
      std::string id_name = "Default";
      ParamsT* reference_params = nullptr;

      // numeric check option is controlled by non-static env var, so check it once per tuned operator
      bool do_numerics_check = ctx->IsNumericsCheckEnabled();

      // calculate a reference answer for the numerical check
      if (do_numerics_check) {
        reference_params = params->DeepCopy(false);
        TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
      }

      // need copies of params to reuse
      // make as many copies as will fill the requested rotating buffer size, if requested
      // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int
      size_t rotating_size = ctx->GetRotatingBufferSize();
      bool use_buffer_rotation = (rotating_size > 0);
      size_t param_size = params->GetSize(use_buffer_rotation);
      size_t param_count = (rotating_size / param_size) + 1;
      constexpr size_t MB = 1024*1024;
      if (use_buffer_rotation) {
        TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
            "Needed Size: ", param_size/MB, " MiB. ",
            "Needed number of param copies: ", param_count);
      }
      TORCH_CHECK(param_count > 0);

      std::vector<ParamsT*> reusable_params(param_count);
      for (size_t i = 0; i < param_count; i++) {
        reusable_params[i] = params->DeepCopy(use_buffer_rotation);
      }

      // for rotating buffer
      size_t offset = 0;

      for (size_t i = 0; i < op_names_.size(); i++) {
        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

        if (do_numerics_check) {
          // Run the candidate on a fresh copy and compare with the reference.
          ParamsT* numerical_params = params->DeepCopy(false);
          auto status = candidate->Call(numerical_params);
          if (status != OK) {
            numerical_params->Delete();
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
          status = reference_params->NumericalCheck(numerical_params);
          numerical_params->Delete();
          if (status != OK) {
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }
        else {
          // No numerics check: a single trial call decides supportedness.
          auto status = candidate->Call(reusable_params[0]);
          if (status != OK) {
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }

        // collect a small profile
        constexpr const int approx_num_iter = 3;
        auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
        // bail if too slow (more than 2x the current best estimate)
        if (approx_duration > 2 * min_duration_ms) {
          TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
          continue;
        }

        // for warmup does user set max duration, max iters, or both?
        // warmup is allowed to be skipped by setting either iterations or duration to 0
        double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
        int max_warmup_iter = ctx->GetMaxWarmupIterations();
        int warmup_iter = 1; // default
        if (max_warmup_duration >= 0) {
          int duration_iters = max_warmup_duration / approx_duration;
          if (max_warmup_iter >= 0) {
            warmup_iter = std::min(max_warmup_iter, duration_iters);
          }
          else {
            warmup_iter = duration_iters;
          }
        }
        else if (max_warmup_iter >= 0) {
          warmup_iter = max_warmup_iter;
        }

        // for tuning does user set max duration, max iters, or both?
        double max_tuning_duration = ctx->GetMaxTuningDurationMs();
        int max_tuning_iter = ctx->GetMaxTuningIterations();
        int tuning_iter = 100; // default
        if (max_tuning_duration > 0) {
          int duration_iters = max_tuning_duration / approx_duration;
          if (max_tuning_iter > 0) {
            tuning_iter = std::min(max_tuning_iter, duration_iters);
          }
          else {
            tuning_iter = duration_iters;
          }
        }
        else if (max_tuning_iter > 0) {
          tuning_iter = max_tuning_iter;
        }
        // tuning must run at least 1 iteration
        tuning_iter = std::max(1, tuning_iter);

        // do the full warmup followed by tuning
        double warmup_ms = warmup_iter * approx_duration;
        double tuning_ms = tuning_iter * approx_duration;
        TUNABLE_LOG3("├──tuning using "
            "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
            "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
            "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
        TUNABLE_LOG3("├──offset at ", offset);
        WarmUp(candidate, reusable_params, warmup_iter, offset);
        auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
        if (duration_ms < min_duration_ms) {
          TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
          min_duration_ms = duration_ms;
          id_name = op_names_[i];
        }
      }

      // Release all deep copies made for benchmarking.
      for (size_t i = 0; i < reusable_params.size(); i++) {
        reusable_params[i]->Delete();
      }
      if (reference_params) {
        reference_params->Delete();
      }

      TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
      return ResultEntry(id_name, min_duration_ms);
    }

  private:
    // Demangled dynamic type name, used as the operator signature.
    std::string CreateSignature() {
#ifndef _WIN32
      const auto* name = typeid(*this).name();
      char buf[256];
      size_t buf_len = 256;
      // NOTE(review): __cxa_demangle's documented contract expects a
      // malloc()'d output buffer that it may realloc(), and the returned
      // pointer/status are ignored here; this relies on the demangled name
      // fitting in 256 bytes and on buf being written — confirm.
      abi::__cxa_demangle(name, buf, &buf_len, nullptr);
      buf[255] = '\0';
      return buf;
#else
      return typeid(*this).name();
#endif
    }

    mutable c10::once_flag signature_init_once_;  // one-time init of signature_
    std::string signature_;

    // Candidates by name, plus registration-ordered names for iteration.
    std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
    std::vector<std::string> op_names_;
};
|
| 279 |
+
|
| 280 |
+
// Abstract base for parameter bundles passed to a tunable operator.
// Concrete subclasses provide a canonical string describing the problem
// (e.g. shapes/layouts), used as the second-level key in the results store.
struct OpParams {
  // Defaulted instead of user-provided `{}` (modernize-use-equals-default);
  // keeps the type trivially default-constructible where possible.
  OpParams() = default;
  virtual ~OpParams() = default;
  // Returns the parameter signature for result caching/lookup.
  virtual std::string Signature() const = 0;
};
|
| 285 |
+
|
| 286 |
+
} // namespace at::cuda::tunable
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
#include <c10/util/string_view.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {
|
| 8 |
+
class Scalar;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
struct TensorIterator;
|
| 13 |
+
struct TensorIteratorBase;
|
| 14 |
+
class TensorBase;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
|
| 19 |
+
// These constants control the approximation behavior of gelu function.
|
| 20 |
+
enum class GeluType {
|
| 21 |
+
None, // Baseline Gelu
|
| 22 |
+
Tanh, // Tahn Gelu Approximation
|
| 23 |
+
END
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
inline GeluType get_gelutype_enum(const c10::string_view approximate) {
|
| 27 |
+
if (approximate == "none") {
|
| 28 |
+
return GeluType::None;
|
| 29 |
+
} else if (approximate == "tanh") {
|
| 30 |
+
return GeluType::Tanh;
|
| 31 |
+
} else {
|
| 32 |
+
TORCH_CHECK(false, "approximate argument must be either none or tanh.");
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
inline std::string gelutype_to_string(const GeluType type) {
|
| 37 |
+
switch(type) {
|
| 38 |
+
case GeluType::None: return "none";
|
| 39 |
+
case GeluType::Tanh: return "tanh";
|
| 40 |
+
default: TORCH_CHECK(false, "unknown GELU type: ", static_cast<int>(type));
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
using structured_activation_fn = void (*)(TensorIteratorBase&);
|
| 45 |
+
using structured_activation_backward_fn = void (*)(TensorIteratorBase&);
|
| 46 |
+
|
| 47 |
+
using activation_fn = void (*)(TensorIterator&);
|
| 48 |
+
using activation_backward_fn = void (*)(TensorIterator&);
|
| 49 |
+
using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
|
| 50 |
+
using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
|
| 51 |
+
using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
|
| 52 |
+
using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
|
| 53 |
+
using hardsigmoid_fn = void(*)(TensorIteratorBase&);
|
| 54 |
+
using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
|
| 55 |
+
using hardswish_fn = void(*)(TensorIterator&);
|
| 56 |
+
using hardswish_backward_fn = void(*)(TensorIterator&);
|
| 57 |
+
using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
|
| 58 |
+
using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
|
| 59 |
+
using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
|
| 60 |
+
using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&);
|
| 61 |
+
using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool);
|
| 62 |
+
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
|
| 63 |
+
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
|
| 64 |
+
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
|
| 65 |
+
using gelu_fn = void (*)(TensorIteratorBase&, GeluType);
|
| 66 |
+
using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType);
|
| 67 |
+
using glu_jvp_fn = void (*)(TensorIteratorBase&);
|
| 68 |
+
|
| 69 |
+
DECLARE_DISPATCH(elu_fn, elu_stub);
|
| 70 |
+
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
|
| 71 |
+
DECLARE_DISPATCH(softplus_fn, softplus_stub);
|
| 72 |
+
DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
|
| 73 |
+
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
|
| 74 |
+
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
|
| 75 |
+
DECLARE_DISPATCH(threshold_fn, threshold_stub);
|
| 76 |
+
DECLARE_DISPATCH(gelu_fn, GeluKernel);
|
| 77 |
+
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
|
| 78 |
+
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
|
| 79 |
+
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
|
| 80 |
+
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
|
| 81 |
+
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
|
| 82 |
+
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
|
| 83 |
+
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
|
| 84 |
+
DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
|
| 85 |
+
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
|
| 86 |
+
DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
|
| 87 |
+
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
|
| 88 |
+
DECLARE_DISPATCH(structured_activation_fn, glu_stub);
|
| 89 |
+
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
|
| 90 |
+
DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub);
|
| 91 |
+
DECLARE_DISPATCH(structured_activation_fn, silu_stub);
|
| 92 |
+
DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub);
|
| 93 |
+
DECLARE_DISPATCH(structured_activation_fn, mish_stub);
|
| 94 |
+
DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub);
|
| 95 |
+
DECLARE_DISPATCH(activation_fn, prelu_stub);
|
| 96 |
+
DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub);
|
| 97 |
+
|
| 98 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/util/ArrayRef.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <cmath>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
|
| 12 |
+
using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
|
| 13 |
+
DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel);
|
| 14 |
+
DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel);
|
| 15 |
+
|
| 16 |
+
using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
|
| 17 |
+
using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
|
| 18 |
+
DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel);
|
| 19 |
+
DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel);
|
| 20 |
+
|
| 21 |
+
using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
|
| 22 |
+
using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
|
| 23 |
+
DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel);
|
| 24 |
+
DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel);
|
| 25 |
+
|
| 26 |
+
using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
|
| 27 |
+
using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
|
| 28 |
+
DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel);
|
| 29 |
+
DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel);
|
| 30 |
+
|
| 31 |
+
inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
|
| 32 |
+
return (a / b) * c + ((a % b) * c) / b;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
inline int64_t end_index(int64_t a, int64_t b, int64_t c) {
|
| 36 |
+
return 1 + ((a + 1) * c - 1) / b;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) {
|
| 40 |
+
int64_t ndim = gradOutput_.ndimension();
|
| 41 |
+
for (const auto i : c10::irange(1, ndim)) {
|
| 42 |
+
TORCH_CHECK(gradOutput_.size(i) > 0,
|
| 43 |
+
arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, "
|
| 44 |
+
"but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i,
|
| 45 |
+
" being empty");
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <optional>
|
| 4 |
+
#include <c10/util/string_view.h>
|
| 5 |
+
#include <ATen/Config.h>
|
| 6 |
+
#include <ATen/native/DispatchStub.h>
|
| 7 |
+
|
| 8 |
+
// Forward declare TI
|
| 9 |
+
namespace at {
|
| 10 |
+
class Tensor;
|
| 11 |
+
struct TensorIterator;
|
| 12 |
+
|
| 13 |
+
namespace native {
|
| 14 |
+
enum class TransposeType;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
namespace at::native {
|
| 20 |
+
|
| 21 |
+
enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
|
| 22 |
+
|
| 23 |
+
#if AT_BUILD_WITH_LAPACK()
|
| 24 |
+
// Define per-batch functions to be used in the implementation of batched
|
| 25 |
+
// linear algebra operations
|
| 26 |
+
|
| 27 |
+
template <class scalar_t>
|
| 28 |
+
void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
|
| 29 |
+
|
| 30 |
+
template <class scalar_t>
|
| 31 |
+
void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
|
| 32 |
+
|
| 33 |
+
template <class scalar_t, class value_t=scalar_t>
|
| 34 |
+
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
|
| 35 |
+
|
| 36 |
+
template <class scalar_t>
|
| 37 |
+
void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
|
| 38 |
+
|
| 39 |
+
template <class scalar_t>
|
| 40 |
+
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
|
| 41 |
+
|
| 42 |
+
template <class scalar_t>
|
| 43 |
+
void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
|
| 44 |
+
|
| 45 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 46 |
+
void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
|
| 47 |
+
|
| 48 |
+
template <class scalar_t>
|
| 49 |
+
void lapackGels(char trans, int m, int n, int nrhs,
|
| 50 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 51 |
+
scalar_t *work, int lwork, int *info);
|
| 52 |
+
|
| 53 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 54 |
+
void lapackGelsd(int m, int n, int nrhs,
|
| 55 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 56 |
+
value_t *s, value_t rcond, int *rank,
|
| 57 |
+
scalar_t* work, int lwork,
|
| 58 |
+
value_t *rwork, int* iwork, int *info);
|
| 59 |
+
|
| 60 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 61 |
+
void lapackGelsy(int m, int n, int nrhs,
|
| 62 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 63 |
+
int *jpvt, value_t rcond, int *rank,
|
| 64 |
+
scalar_t *work, int lwork, value_t* rwork, int *info);
|
| 65 |
+
|
| 66 |
+
template <class scalar_t, class value_t = scalar_t>
|
| 67 |
+
void lapackGelss(int m, int n, int nrhs,
|
| 68 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 69 |
+
value_t *s, value_t rcond, int *rank,
|
| 70 |
+
scalar_t *work, int lwork,
|
| 71 |
+
value_t *rwork, int *info);
|
| 72 |
+
|
| 73 |
+
template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
|
| 74 |
+
struct lapackLstsq_impl;
|
| 75 |
+
|
| 76 |
+
template <class scalar_t, class value_t>
|
| 77 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
|
| 78 |
+
static void call(
|
| 79 |
+
char trans, int m, int n, int nrhs,
|
| 80 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 81 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 82 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 83 |
+
value_t *s, // Gelss flavor
|
| 84 |
+
int *iwork // Gelsd flavor
|
| 85 |
+
) {
|
| 86 |
+
lapackGels<scalar_t>(
|
| 87 |
+
trans, m, n, nrhs,
|
| 88 |
+
a, lda, b, ldb,
|
| 89 |
+
work, lwork, info);
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
template <class scalar_t, class value_t>
|
| 94 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
|
| 95 |
+
static void call(
|
| 96 |
+
char trans, int m, int n, int nrhs,
|
| 97 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 98 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 99 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 100 |
+
value_t *s, // Gelss flavor
|
| 101 |
+
int *iwork // Gelsd flavor
|
| 102 |
+
) {
|
| 103 |
+
lapackGelsy<scalar_t, value_t>(
|
| 104 |
+
m, n, nrhs,
|
| 105 |
+
a, lda, b, ldb,
|
| 106 |
+
jpvt, rcond, rank,
|
| 107 |
+
work, lwork, rwork, info);
|
| 108 |
+
}
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
template <class scalar_t, class value_t>
|
| 112 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
|
| 113 |
+
static void call(
|
| 114 |
+
char trans, int m, int n, int nrhs,
|
| 115 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 116 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 117 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 118 |
+
value_t *s, // Gelss flavor
|
| 119 |
+
int *iwork // Gelsd flavor
|
| 120 |
+
) {
|
| 121 |
+
lapackGelsd<scalar_t, value_t>(
|
| 122 |
+
m, n, nrhs,
|
| 123 |
+
a, lda, b, ldb,
|
| 124 |
+
s, rcond, rank,
|
| 125 |
+
work, lwork,
|
| 126 |
+
rwork, iwork, info);
|
| 127 |
+
}
|
| 128 |
+
};
|
| 129 |
+
|
| 130 |
+
template <class scalar_t, class value_t>
|
| 131 |
+
struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
|
| 132 |
+
static void call(
|
| 133 |
+
char trans, int m, int n, int nrhs,
|
| 134 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 135 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 136 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 137 |
+
value_t *s, // Gelss flavor
|
| 138 |
+
int *iwork // Gelsd flavor
|
| 139 |
+
) {
|
| 140 |
+
lapackGelss<scalar_t, value_t>(
|
| 141 |
+
m, n, nrhs,
|
| 142 |
+
a, lda, b, ldb,
|
| 143 |
+
s, rcond, rank,
|
| 144 |
+
work, lwork,
|
| 145 |
+
rwork, info);
|
| 146 |
+
}
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
|
| 150 |
+
void lapackLstsq(
|
| 151 |
+
char trans, int m, int n, int nrhs,
|
| 152 |
+
scalar_t *a, int lda, scalar_t *b, int ldb,
|
| 153 |
+
scalar_t *work, int lwork, int *info, // Gels flavor
|
| 154 |
+
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
|
| 155 |
+
value_t *s, // Gelss flavor
|
| 156 |
+
int *iwork // Gelsd flavor
|
| 157 |
+
) {
|
| 158 |
+
lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
|
| 159 |
+
trans, m, n, nrhs,
|
| 160 |
+
a, lda, b, ldb,
|
| 161 |
+
work, lwork, info,
|
| 162 |
+
jpvt, rcond, rank, rwork,
|
| 163 |
+
s,
|
| 164 |
+
iwork);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
template <class scalar_t>
|
| 168 |
+
void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
|
| 169 |
+
|
| 170 |
+
template <class scalar_t>
|
| 171 |
+
void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
|
| 172 |
+
|
| 173 |
+
template <class scalar_t>
|
| 174 |
+
void lapackLdlHermitian(
|
| 175 |
+
char uplo,
|
| 176 |
+
int n,
|
| 177 |
+
scalar_t* a,
|
| 178 |
+
int lda,
|
| 179 |
+
int* ipiv,
|
| 180 |
+
scalar_t* work,
|
| 181 |
+
int lwork,
|
| 182 |
+
int* info);
|
| 183 |
+
|
| 184 |
+
template <class scalar_t>
|
| 185 |
+
void lapackLdlSymmetric(
|
| 186 |
+
char uplo,
|
| 187 |
+
int n,
|
| 188 |
+
scalar_t* a,
|
| 189 |
+
int lda,
|
| 190 |
+
int* ipiv,
|
| 191 |
+
scalar_t* work,
|
| 192 |
+
int lwork,
|
| 193 |
+
int* info);
|
| 194 |
+
|
| 195 |
+
template <class scalar_t>
|
| 196 |
+
void lapackLdlSolveHermitian(
|
| 197 |
+
char uplo,
|
| 198 |
+
int n,
|
| 199 |
+
int nrhs,
|
| 200 |
+
scalar_t* a,
|
| 201 |
+
int lda,
|
| 202 |
+
int* ipiv,
|
| 203 |
+
scalar_t* b,
|
| 204 |
+
int ldb,
|
| 205 |
+
int* info);
|
| 206 |
+
|
| 207 |
+
template <class scalar_t>
|
| 208 |
+
void lapackLdlSolveSymmetric(
|
| 209 |
+
char uplo,
|
| 210 |
+
int n,
|
| 211 |
+
int nrhs,
|
| 212 |
+
scalar_t* a,
|
| 213 |
+
int lda,
|
| 214 |
+
int* ipiv,
|
| 215 |
+
scalar_t* b,
|
| 216 |
+
int ldb,
|
| 217 |
+
int* info);
|
| 218 |
+
|
| 219 |
+
template<class scalar_t, class value_t=scalar_t>
|
| 220 |
+
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
|
| 221 |
+
#endif
|
| 222 |
+
|
| 223 |
+
#if AT_BUILD_WITH_BLAS()
|
| 224 |
+
template <class scalar_t>
|
| 225 |
+
void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
|
| 226 |
+
#endif
|
| 227 |
+
|
| 228 |
+
using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
|
| 229 |
+
DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
|
| 230 |
+
|
| 231 |
+
using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
|
| 232 |
+
|
| 233 |
+
DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
|
| 234 |
+
|
| 235 |
+
using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
|
| 236 |
+
|
| 237 |
+
DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
|
| 238 |
+
|
| 239 |
+
using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
|
| 240 |
+
DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
|
| 241 |
+
|
| 242 |
+
using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
|
| 243 |
+
DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
|
| 244 |
+
|
| 245 |
+
using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
|
| 246 |
+
DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
|
| 247 |
+
|
| 248 |
+
using linalg_eigh_fn = void (*)(
|
| 249 |
+
const Tensor& /*eigenvalues*/,
|
| 250 |
+
const Tensor& /*eigenvectors*/,
|
| 251 |
+
const Tensor& /*infos*/,
|
| 252 |
+
bool /*upper*/,
|
| 253 |
+
bool /*compute_eigenvectors*/);
|
| 254 |
+
DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
|
| 255 |
+
|
| 256 |
+
using lstsq_fn = void (*)(
|
| 257 |
+
const Tensor& /*a*/,
|
| 258 |
+
Tensor& /*b*/,
|
| 259 |
+
Tensor& /*rank*/,
|
| 260 |
+
Tensor& /*singular_values*/,
|
| 261 |
+
Tensor& /*infos*/,
|
| 262 |
+
double /*rcond*/,
|
| 263 |
+
std::string /*driver_name*/);
|
| 264 |
+
DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
|
| 265 |
+
|
| 266 |
+
using triangular_solve_fn = void (*)(
|
| 267 |
+
const Tensor& /*A*/,
|
| 268 |
+
const Tensor& /*B*/,
|
| 269 |
+
bool /*left*/,
|
| 270 |
+
bool /*upper*/,
|
| 271 |
+
TransposeType /*transpose*/,
|
| 272 |
+
bool /*unitriangular*/);
|
| 273 |
+
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
|
| 274 |
+
|
| 275 |
+
using lu_factor_fn = void (*)(
|
| 276 |
+
const Tensor& /*input*/,
|
| 277 |
+
const Tensor& /*pivots*/,
|
| 278 |
+
const Tensor& /*infos*/,
|
| 279 |
+
bool /*compute_pivots*/);
|
| 280 |
+
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
|
| 281 |
+
|
| 282 |
+
using unpack_pivots_fn = void(*)(
|
| 283 |
+
TensorIterator& iter,
|
| 284 |
+
const int64_t dim_size,
|
| 285 |
+
const int64_t max_pivot);
|
| 286 |
+
DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
|
| 287 |
+
|
| 288 |
+
using lu_solve_fn = void (*)(
|
| 289 |
+
const Tensor& /*LU*/,
|
| 290 |
+
const Tensor& /*pivots*/,
|
| 291 |
+
const Tensor& /*B*/,
|
| 292 |
+
TransposeType /*trans*/);
|
| 293 |
+
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
|
| 294 |
+
|
| 295 |
+
using ldl_factor_fn = void (*)(
|
| 296 |
+
const Tensor& /*LD*/,
|
| 297 |
+
const Tensor& /*pivots*/,
|
| 298 |
+
const Tensor& /*info*/,
|
| 299 |
+
bool /*upper*/,
|
| 300 |
+
bool /*hermitian*/);
|
| 301 |
+
DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
|
| 302 |
+
|
| 303 |
+
using svd_fn = void (*)(
|
| 304 |
+
const Tensor& /*A*/,
|
| 305 |
+
const bool /*full_matrices*/,
|
| 306 |
+
const bool /*compute_uv*/,
|
| 307 |
+
const std::optional<c10::string_view>& /*driver*/,
|
| 308 |
+
const Tensor& /*U*/,
|
| 309 |
+
const Tensor& /*S*/,
|
| 310 |
+
const Tensor& /*Vh*/,
|
| 311 |
+
const Tensor& /*info*/);
|
| 312 |
+
DECLARE_DISPATCH(svd_fn, svd_stub);
|
| 313 |
+
|
| 314 |
+
using ldl_solve_fn = void (*)(
|
| 315 |
+
const Tensor& /*LD*/,
|
| 316 |
+
const Tensor& /*pivots*/,
|
| 317 |
+
const Tensor& /*result*/,
|
| 318 |
+
bool /*upper*/,
|
| 319 |
+
bool /*hermitian*/);
|
| 320 |
+
DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
|
| 321 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/util/TypeSafeSignMath.h>
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
struct TensorIterator;
|
| 11 |
+
struct TensorIteratorBase;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
|
| 17 |
+
TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
|
| 18 |
+
"Boolean alpha only supported for Boolean results.");
|
| 19 |
+
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
|
| 20 |
+
|| alpha.isIntegral(true),
|
| 21 |
+
"For integral input tensors, argument alpha must not be a floating point number.");
|
| 22 |
+
TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
|
| 23 |
+
"For non-complex input tensors, argument alpha must not be a complex number.")
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
// Basic checking for all sub functions.
|
| 27 |
+
inline void sub_check(const TensorBase& self, const TensorBase& other) {
|
| 28 |
+
TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
|
| 29 |
+
"Subtraction, the `-` operator, with two bool tensors is not supported. "
|
| 30 |
+
"Use the `^` or `logical_xor()` operator instead.")
|
| 31 |
+
TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
|
| 32 |
+
"Subtraction, the `-` operator, with a bool tensor is not supported. "
|
| 33 |
+
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
inline void sub_check(const TensorBase& self, const Scalar& scalar) {
|
| 37 |
+
TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
|
| 38 |
+
"Subtraction, the `-` operator, with two bool tensors is not supported. "
|
| 39 |
+
"Use the `^` or `logical_xor()` operator instead.")
|
| 40 |
+
TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
|
| 41 |
+
"Subtraction, the `-` operator, with a bool tensor is not supported. "
|
| 42 |
+
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
|
| 46 |
+
using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
|
| 47 |
+
using structured_binary_fn = void(*)(TensorIteratorBase&);
|
| 48 |
+
|
| 49 |
+
using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
|
| 50 |
+
using binary_fn_double = void(*)(TensorIterator&, double);
|
| 51 |
+
using binary_fn = void(*)(TensorIterator&);
|
| 52 |
+
using binary_clamp_fn_alpha =
|
| 53 |
+
void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
|
| 54 |
+
|
| 55 |
+
// NB: codegenned
|
| 56 |
+
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
|
| 57 |
+
|
| 58 |
+
DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub);
|
| 59 |
+
DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub);
|
| 60 |
+
DECLARE_DISPATCH(structured_binary_fn, mul_stub);
|
| 61 |
+
DECLARE_DISPATCH(structured_binary_fn, div_true_stub);
|
| 62 |
+
DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
|
| 63 |
+
DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
|
| 64 |
+
DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
|
| 65 |
+
DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
|
| 66 |
+
DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
|
| 67 |
+
DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
|
| 68 |
+
DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
|
| 69 |
+
DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
|
| 70 |
+
DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
|
| 71 |
+
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
|
| 72 |
+
DECLARE_DISPATCH(binary_fn, logical_and_stub);
|
| 73 |
+
DECLARE_DISPATCH(binary_fn, logical_or_stub);
|
| 74 |
+
DECLARE_DISPATCH(structured_binary_fn, lt_stub);
|
| 75 |
+
DECLARE_DISPATCH(structured_binary_fn, le_stub);
|
| 76 |
+
DECLARE_DISPATCH(structured_binary_fn, gt_stub);
|
| 77 |
+
DECLARE_DISPATCH(structured_binary_fn, ge_stub);
|
| 78 |
+
DECLARE_DISPATCH(structured_binary_fn, eq_stub);
|
| 79 |
+
DECLARE_DISPATCH(structured_binary_fn, ne_stub);
|
| 80 |
+
DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
|
| 81 |
+
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
|
| 82 |
+
DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
|
| 83 |
+
DECLARE_DISPATCH(structured_binary_fn, minimum_stub);
|
| 84 |
+
DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
|
| 85 |
+
DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
|
| 86 |
+
DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub);
|
| 87 |
+
DECLARE_DISPATCH(binary_fn_double, huber_stub);
|
| 88 |
+
DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
|
| 89 |
+
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
|
| 90 |
+
DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
|
| 91 |
+
DECLARE_DISPATCH(structured_binary_fn, mse_stub);
|
| 92 |
+
DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
|
| 93 |
+
DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
|
| 94 |
+
DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub);
|
| 95 |
+
DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
|
| 96 |
+
DECLARE_DISPATCH(structured_binary_fn, lcm_stub);
|
| 97 |
+
DECLARE_DISPATCH(structured_binary_fn, hypot_stub);
|
| 98 |
+
DECLARE_DISPATCH(structured_binary_fn, igamma_stub);
|
| 99 |
+
DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
|
| 100 |
+
DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
|
| 101 |
+
DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
|
| 102 |
+
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
|
| 103 |
+
DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
|
| 104 |
+
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
|
| 105 |
+
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
|
| 106 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
|
| 107 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
|
| 108 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub);
|
| 109 |
+
DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub);
|
| 110 |
+
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
|
| 111 |
+
DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
|
| 112 |
+
DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub);
|
| 113 |
+
DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub);
|
| 114 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub);
|
| 115 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub);
|
| 116 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub);
|
| 117 |
+
DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub);
|
| 118 |
+
|
| 119 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/NativeFunctions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/view_as_real_native.h>
|
| 10 |
+
#include <ATen/ops/view_as_complex_native.h>
|
| 11 |
+
|
| 12 |
+
#include <utility>
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
// WARNING: this header contains non-inline functions and should be only
|
| 16 |
+
// included from ONE cpp file
|
| 17 |
+
|
| 18 |
+
namespace at::native {
|
| 19 |
+
|
| 20 |
+
// View tensor with new dtype, storage offset, sizes and strides
|
| 21 |
+
inline Tensor view_tensor(
|
| 22 |
+
const Tensor &tensor, ScalarType dtype,
|
| 23 |
+
c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
|
| 24 |
+
Storage storage = tensor.storage();
|
| 25 |
+
auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
|
| 26 |
+
auto new_tensor = detail::make_tensor<TensorImpl>(
|
| 27 |
+
c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
|
| 28 |
+
auto * impl = new_tensor.unsafeGetTensorImpl();
|
| 29 |
+
impl->set_sizes_and_strides(sizes, strides, offset);
|
| 30 |
+
return new_tensor;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
|
| 34 |
+
SymDimVector res(oldstride.size() + 1);
|
| 35 |
+
for (const auto i : c10::irange(oldstride.size())) {
|
| 36 |
+
res[i] = oldstride[i] * 2;
|
| 37 |
+
}
|
| 38 |
+
res.back() = 1;
|
| 39 |
+
return res;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
inline Tensor _view_as_real_physical(const Tensor& self) {
|
| 43 |
+
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
|
| 44 |
+
auto old_sizes = self.sym_sizes();
|
| 45 |
+
SymDimVector new_sizes(old_sizes.size() + 1);
|
| 46 |
+
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
|
| 47 |
+
// last dimension will always have two elements containing the real and imag vals
|
| 48 |
+
new_sizes.back() = 2;
|
| 49 |
+
auto new_strides = computeStrideForViewAsReal(self.sym_strides());
|
| 50 |
+
auto new_storage_offset = self.sym_storage_offset() * 2;
|
| 51 |
+
const auto float_type = c10::toRealValueType(self.scalar_type());
|
| 52 |
+
auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
|
| 53 |
+
return real_tensor;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// expects as input a complex tensor and returns back a tensor
|
| 57 |
+
// with corresponding real dtype containing the complex values
|
| 58 |
+
// in the last two dimensions
|
| 59 |
+
Tensor view_as_real(const Tensor& self) {
|
| 60 |
+
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
|
| 61 |
+
return _view_as_real_physical(self);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
|
| 65 |
+
const auto dim = oldstride.size();
|
| 66 |
+
TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1");
|
| 67 |
+
|
| 68 |
+
SymDimVector res(dim - 1);
|
| 69 |
+
for (const auto i : c10::irange(res.size())) {
|
| 70 |
+
TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
|
| 71 |
+
res[i] = oldstride[i] / 2;
|
| 72 |
+
}
|
| 73 |
+
return res;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// expects as input a float or double tensor with last dimension of size 2
|
| 77 |
+
// and returns back a tensor with corresponding complex dtype
|
| 78 |
+
Tensor view_as_complex(const Tensor& self) {
|
| 79 |
+
TORCH_CHECK(
|
| 80 |
+
self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
|
| 81 |
+
"view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
|
| 82 |
+
|
| 83 |
+
auto old_sizes = self.sym_sizes();
|
| 84 |
+
TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
|
| 85 |
+
TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
|
| 86 |
+
SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
|
| 87 |
+
|
| 88 |
+
const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
|
| 89 |
+
const auto complex_type = c10::toComplexType(self.scalar_type());
|
| 90 |
+
|
| 91 |
+
TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
|
| 92 |
+
const auto new_storage_offset = self.sym_storage_offset() / 2;
|
| 93 |
+
|
| 94 |
+
return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
|
| 8 |
+
namespace native {
|
| 9 |
+
|
| 10 |
+
using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
|
| 11 |
+
using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
|
| 12 |
+
using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
|
| 13 |
+
using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
|
| 14 |
+
|
| 15 |
+
DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
|
| 16 |
+
DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
|
| 17 |
+
DECLARE_DISPATCH(cdist_fn, cdist_stub);
|
| 18 |
+
DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub);
|
| 19 |
+
|
| 20 |
+
}} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Functions that fill Tensors with constants. Implementations are in Fill.cpp.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {
|
| 8 |
+
class Scalar;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
class Tensor;
|
| 13 |
+
struct TensorIterator;
|
| 14 |
+
|
| 15 |
+
namespace native {
|
| 16 |
+
|
| 17 |
+
DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub);
|
| 18 |
+
|
| 19 |
+
Tensor& fill_out(Tensor& self, const Scalar& value);
|
| 20 |
+
|
| 21 |
+
}} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/TensorUtils.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
template<typename scalar_t>
|
| 9 |
+
inline std::vector<int64_t> generate_intervals(
|
| 10 |
+
scalar_t sample,
|
| 11 |
+
int64_t inputSize,
|
| 12 |
+
int64_t outputSize,
|
| 13 |
+
int64_t poolSize) {
|
| 14 |
+
std::vector<int64_t> sequence(outputSize);
|
| 15 |
+
if (outputSize > 1) {
|
| 16 |
+
scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
|
| 17 |
+
static_cast<scalar_t>(outputSize - 1);
|
| 18 |
+
|
| 19 |
+
for (const auto i : c10::irange(outputSize - 1)) {
|
| 20 |
+
sequence[i] =
|
| 21 |
+
static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
if (outputSize > 0) {
|
| 25 |
+
sequence[outputSize - 1] = inputSize - poolSize;
|
| 26 |
+
}
|
| 27 |
+
return sequence;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <int64_t ndim>
|
| 31 |
+
inline void fractional_max_pool_check_shape(
|
| 32 |
+
const Tensor& input,
|
| 33 |
+
const Tensor& randomSamples) {
|
| 34 |
+
|
| 35 |
+
TORCH_CHECK(
|
| 36 |
+
input.scalar_type() == randomSamples.scalar_type(),
|
| 37 |
+
"Expect _random_samples to have the same dtype as input");
|
| 38 |
+
|
| 39 |
+
int64_t ndimension = randomSamples.ndimension();
|
| 40 |
+
TORCH_CHECK(
|
| 41 |
+
ndimension == 3,
|
| 42 |
+
"Expect _random_samples to have 3 dimensions, got ", ndimension);
|
| 43 |
+
|
| 44 |
+
int64_t N = randomSamples.size(0);
|
| 45 |
+
int64_t C = randomSamples.size(1);
|
| 46 |
+
int64_t D = randomSamples.size(2);
|
| 47 |
+
|
| 48 |
+
int64_t input_batch = 0, input_channel = 0;
|
| 49 |
+
if (ndim == 2) {
|
| 50 |
+
// fractional_max_pool2d
|
| 51 |
+
if (input.ndimension() == 3) {
|
| 52 |
+
input_batch = 1;
|
| 53 |
+
input_channel = input.size(0);
|
| 54 |
+
} else {
|
| 55 |
+
input_batch = input.size(0);
|
| 56 |
+
input_channel = input.size(1);
|
| 57 |
+
}
|
| 58 |
+
} else {
|
| 59 |
+
// factional_max_pool3d
|
| 60 |
+
if (input.ndimension() == 4) {
|
| 61 |
+
input_batch = 1;
|
| 62 |
+
input_channel = input.size(0);
|
| 63 |
+
} else {
|
| 64 |
+
input_batch = input.size(0);
|
| 65 |
+
input_channel = input.size(1);
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
TORCH_CHECK(
|
| 70 |
+
N >= input_batch,
|
| 71 |
+
"Expect _random_samples.size(0) no less then input batch size.");
|
| 72 |
+
TORCH_CHECK(
|
| 73 |
+
C == input_channel,
|
| 74 |
+
"Expect _random_samples.size(1) equals to input channel size.");
|
| 75 |
+
TORCH_CHECK(
|
| 76 |
+
D == ndim,
|
| 77 |
+
"Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
struct TensorIterator;
|
| 8 |
+
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
using _compute_linear_combination_fn = void(*)(
|
| 12 |
+
TensorIterator& iter,
|
| 13 |
+
int64_t in_stride,
|
| 14 |
+
int64_t coeff_stride,
|
| 15 |
+
int64_t num_summations
|
| 16 |
+
);
|
| 17 |
+
|
| 18 |
+
DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub);
|
| 19 |
+
|
| 20 |
+
}} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <cmath>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
#include <utility>
|
| 7 |
+
|
| 8 |
+
#include <ATen/native/GridSamplerUtils.h>
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
using detail::GridSamplerInterpolation;
|
| 13 |
+
using detail::GridSamplerPadding;
|
| 14 |
+
|
| 15 |
+
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
|
| 16 |
+
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
|
| 17 |
+
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
|
| 18 |
+
// -1 --> 0
|
| 19 |
+
// +1 --> (size - 1)
|
| 20 |
+
// scale_factor = (size - 1) / 2
|
| 21 |
+
// if not align_corners: -1 and +1 get sent to the image edges
|
| 22 |
+
// -1 --> -0.5
|
| 23 |
+
// +1 --> (size - 1) + 0.5 == size - 0.5
|
| 24 |
+
// scale_factor = size / 2
|
| 25 |
+
template <typename scalar_t>
|
| 26 |
+
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
|
| 27 |
+
bool align_corners) {
|
| 28 |
+
if (align_corners) {
|
| 29 |
+
// unnormalize coord from [-1, 1] to [0, size - 1]
|
| 30 |
+
return ((coord + 1) / 2) * (size - 1);
|
| 31 |
+
} else {
|
| 32 |
+
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
| 33 |
+
return ((coord + 1) * size - 1) / 2;
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
|
| 38 |
+
// except that it also returns the `d output / d input` via pointer argument
|
| 39 |
+
// `grad_in`.
|
| 40 |
+
// This is useful in the backward pass of grid_sampler.
|
| 41 |
+
template <typename scalar_t>
|
| 42 |
+
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
|
| 43 |
+
bool align_corners, scalar_t *grad_in) {
|
| 44 |
+
if (align_corners) {
|
| 45 |
+
// unnormalize coord from [-1, 1] to [0, size - 1]
|
| 46 |
+
*grad_in = static_cast<scalar_t>(size - 1) / 2;
|
| 47 |
+
return ((coord + 1) / 2) * (size - 1);
|
| 48 |
+
} else {
|
| 49 |
+
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
| 50 |
+
*grad_in = static_cast<scalar_t>(size) / 2;
|
| 51 |
+
return ((coord + 1) * size - 1) / 2;
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
// Clips coordinates to between 0 and clip_limit - 1
|
| 56 |
+
template<typename scalar_t>
|
| 57 |
+
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
|
| 58 |
+
return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// clip_coordinates_set_grad works similarly to clip_coordinates except that
|
| 62 |
+
// it also returns the `d output / d input` via pointer argument `grad_in`.
|
| 63 |
+
// This is useful in the backward pass of grid_sampler.
|
| 64 |
+
template<typename scalar_t>
|
| 65 |
+
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
|
| 66 |
+
scalar_t *grad_in) {
|
| 67 |
+
// Note that it is important for the gradient calculation that borders
|
| 68 |
+
// are considered out of bounds.
|
| 69 |
+
if (in <= static_cast<scalar_t>(0)) {
|
| 70 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 71 |
+
return static_cast<scalar_t>(0);
|
| 72 |
+
} else {
|
| 73 |
+
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
|
| 74 |
+
if (in >= max) {
|
| 75 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 76 |
+
return max;
|
| 77 |
+
} else {
|
| 78 |
+
*grad_in = static_cast<scalar_t>(1);
|
| 79 |
+
return in;
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Reflects coordinates until they fall between low and high (inclusive).
|
| 85 |
+
// The bounds are passed as twice their value so that half-integer values
|
| 86 |
+
// can be represented as ints.
|
| 87 |
+
template<typename scalar_t>
|
| 88 |
+
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
|
| 89 |
+
int64_t twice_high) {
|
| 90 |
+
if (twice_low == twice_high) {
|
| 91 |
+
return static_cast<scalar_t>(0);
|
| 92 |
+
}
|
| 93 |
+
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
|
| 94 |
+
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
|
| 95 |
+
in = std::fabs(in - min);
|
| 96 |
+
// `fmod` returns same sign as `in`, which is positive after the `fabs` above.
|
| 97 |
+
scalar_t extra = std::fmod(in, span);
|
| 98 |
+
int flips = static_cast<int>(std::floor(in / span));
|
| 99 |
+
if (flips % 2 == 0) {
|
| 100 |
+
return extra + min;
|
| 101 |
+
} else {
|
| 102 |
+
return span - extra + min;
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
|
| 107 |
+
// that it also returns the `d output / d input` via pointer argument
|
| 108 |
+
// `grad_in`.
|
| 109 |
+
// This is useful in the backward pass of grid_sampler.
|
| 110 |
+
template<typename scalar_t>
|
| 111 |
+
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
|
| 112 |
+
int64_t twice_high, scalar_t *grad_in) {
|
| 113 |
+
if (twice_low == twice_high) {
|
| 114 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 115 |
+
return static_cast<scalar_t>(0);
|
| 116 |
+
}
|
| 117 |
+
int grad_in_mult_;
|
| 118 |
+
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
|
| 119 |
+
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
|
| 120 |
+
in = in - min;
|
| 121 |
+
if (in < static_cast<scalar_t>(0)) {
|
| 122 |
+
grad_in_mult_ = -1;
|
| 123 |
+
in = -in;
|
| 124 |
+
} else {
|
| 125 |
+
grad_in_mult_ = 1;
|
| 126 |
+
}
|
| 127 |
+
// `fmod` returns same sign as `in`, which is positive after the `if` above.
|
| 128 |
+
scalar_t extra = std::fmod(in, span);
|
| 129 |
+
int flips = static_cast<int>(std::floor(in / span));
|
| 130 |
+
if (flips % 2 == 0) {
|
| 131 |
+
*grad_in = static_cast<scalar_t>(grad_in_mult_);
|
| 132 |
+
return extra + min;
|
| 133 |
+
} else {
|
| 134 |
+
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
|
| 135 |
+
return span - extra + min;
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// Mapping the out-of-boundary points back into boundary
|
| 140 |
+
// This would only affect padding_mode=border or reflection
|
| 141 |
+
template<typename scalar_t>
|
| 142 |
+
static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
|
| 143 |
+
GridSamplerPadding padding_mode,
|
| 144 |
+
bool align_corners) {
|
| 145 |
+
if (padding_mode == GridSamplerPadding::Border) {
|
| 146 |
+
// clip coordinates to image borders
|
| 147 |
+
coord = clip_coordinates(coord, size);
|
| 148 |
+
} else if (padding_mode == GridSamplerPadding::Reflection) {
|
| 149 |
+
// reflect coordinates by image borders
|
| 150 |
+
if (align_corners) {
|
| 151 |
+
coord = reflect_coordinates(coord, 0, 2*(size - 1));
|
| 152 |
+
} else {
|
| 153 |
+
coord = reflect_coordinates(coord, -1, 2*size - 1);
|
| 154 |
+
}
|
| 155 |
+
// clip coordinates to image borders
|
| 156 |
+
coord = clip_coordinates(coord, size);
|
| 157 |
+
}
|
| 158 |
+
return coord;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// Computes the pixel source index value for a grid coordinate
|
| 162 |
+
template <typename scalar_t>
|
| 163 |
+
static inline scalar_t grid_sampler_compute_source_index(
|
| 164 |
+
scalar_t coord,
|
| 165 |
+
int64_t size,
|
| 166 |
+
GridSamplerPadding padding_mode,
|
| 167 |
+
bool align_corners) {
|
| 168 |
+
coord = grid_sampler_unnormalize(coord, size, align_corners);
|
| 169 |
+
coord = compute_coordinates(coord, size, padding_mode, align_corners);
|
| 170 |
+
return coord;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// grid_sampler_compute_source_index_set_grad works similarly to
|
| 174 |
+
// grid_sampler_compute_source_index except that it also returns the
|
| 175 |
+
// `d output / d input` via pointer argument `grad_in`.
|
| 176 |
+
// This is useful in the backward pass of grid_sampler.
|
| 177 |
+
template <typename scalar_t>
|
| 178 |
+
static inline scalar_t grid_sampler_compute_source_index_set_grad(
|
| 179 |
+
scalar_t coord,
|
| 180 |
+
int64_t size,
|
| 181 |
+
GridSamplerPadding padding_mode,
|
| 182 |
+
bool align_corners,
|
| 183 |
+
scalar_t *grad_in) {
|
| 184 |
+
scalar_t grad_clip, grad_refl;
|
| 185 |
+
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
|
| 186 |
+
if (padding_mode == GridSamplerPadding::Border) {
|
| 187 |
+
// clip coordinates to image borders
|
| 188 |
+
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
| 189 |
+
*grad_in = (*grad_in) * grad_clip;
|
| 190 |
+
} else if (padding_mode == GridSamplerPadding::Reflection) {
|
| 191 |
+
// reflect coordinates by image borders
|
| 192 |
+
if (align_corners) {
|
| 193 |
+
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
|
| 194 |
+
} else {
|
| 195 |
+
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
|
| 196 |
+
}
|
| 197 |
+
// clip coordinates to image borders
|
| 198 |
+
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
| 199 |
+
*grad_in = (*grad_in) * grad_refl * grad_clip;
|
| 200 |
+
}
|
| 201 |
+
return coord;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
|
| 205 |
+
return h >= 0 && h < H && w >= 0 && w < W;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
|
| 209 |
+
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
template<typename scalar_t>
|
| 213 |
+
static inline scalar_t get_value_bounded(
|
| 214 |
+
const scalar_t* data,
|
| 215 |
+
scalar_t x,
|
| 216 |
+
scalar_t y,
|
| 217 |
+
int64_t W,
|
| 218 |
+
int64_t H,
|
| 219 |
+
int64_t sW,
|
| 220 |
+
int64_t sH,
|
| 221 |
+
GridSamplerPadding padding_mode,
|
| 222 |
+
bool align_corners) {
|
| 223 |
+
|
| 224 |
+
x = compute_coordinates(x, W, padding_mode, align_corners);
|
| 225 |
+
y = compute_coordinates(y, H, padding_mode, align_corners);
|
| 226 |
+
|
| 227 |
+
int64_t ix = static_cast<int64_t>(x);
|
| 228 |
+
int64_t iy = static_cast<int64_t>(y);
|
| 229 |
+
|
| 230 |
+
if (within_bounds_2d(iy, ix, H, W)) {
|
| 231 |
+
return data[iy * sH + ix * sW];
|
| 232 |
+
}
|
| 233 |
+
return static_cast<scalar_t>(0);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
template<typename scalar_t>
|
| 237 |
+
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
|
| 238 |
+
int64_t sH, int64_t sW, int64_t H, int64_t W,
|
| 239 |
+
scalar_t delta) {
|
| 240 |
+
if (within_bounds_2d(h, w, H, W)) {
|
| 241 |
+
data[h * sH + w * sW] += delta;
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
template<typename scalar_t>
|
| 246 |
+
static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
|
| 247 |
+
int64_t sD, int64_t sH, int64_t sW,
|
| 248 |
+
int64_t D, int64_t H, int64_t W,
|
| 249 |
+
scalar_t delta) {
|
| 250 |
+
if (within_bounds_3d(d, h, w, D, H, W)) {
|
| 251 |
+
data[d * sD + h * sH + w * sW] += delta;
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
template<typename scalar_t>
|
| 256 |
+
static inline void add_value_bounded(
|
| 257 |
+
scalar_t* data,
|
| 258 |
+
scalar_t x,
|
| 259 |
+
scalar_t y,
|
| 260 |
+
int64_t W,
|
| 261 |
+
int64_t H,
|
| 262 |
+
int64_t sW,
|
| 263 |
+
int64_t sH,
|
| 264 |
+
scalar_t delta,
|
| 265 |
+
GridSamplerPadding padding_mode,
|
| 266 |
+
bool align_corners) {
|
| 267 |
+
|
| 268 |
+
x = compute_coordinates(x, W, padding_mode, align_corners);
|
| 269 |
+
y = compute_coordinates(y, H, padding_mode, align_corners);
|
| 270 |
+
|
| 271 |
+
int64_t ix = static_cast<int64_t>(x);
|
| 272 |
+
int64_t iy = static_cast<int64_t>(y);
|
| 273 |
+
|
| 274 |
+
safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
|
| 278 |
+
template<typename scalar_t>
|
| 279 |
+
static inline void get_cubic_coefficients_grad(
|
| 280 |
+
scalar_t coeffs[4],
|
| 281 |
+
scalar_t t) {
|
| 282 |
+
|
| 283 |
+
// Must be the same as forward calculation in
|
| 284 |
+
// aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
|
| 285 |
+
scalar_t A = -0.75;
|
| 286 |
+
|
| 287 |
+
scalar_t x;
|
| 288 |
+
x = -1 - t; // 1 < x = |-1 - tx| < 2
|
| 289 |
+
coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
|
| 290 |
+
x = -t; // x = |0 - tx| <= 1
|
| 291 |
+
coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
|
| 292 |
+
x = 1 - t; // x = |1 - tx| <= 1
|
| 293 |
+
coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
|
| 294 |
+
x = 2 - t; // 1 < x = |2 - tx| < 2
|
| 295 |
+
coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/AccumulateType.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/TensorUtils.h>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
inline void multilabel_margin_loss_shape_check(
|
| 9 |
+
int64_t& nframe,
|
| 10 |
+
int64_t& dim,
|
| 11 |
+
const int64_t& ndims,
|
| 12 |
+
const Tensor& input,
|
| 13 |
+
const Tensor& target) {
|
| 14 |
+
TORCH_CHECK(
|
| 15 |
+
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
|
| 16 |
+
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
|
| 17 |
+
input.sizes());
|
| 18 |
+
|
| 19 |
+
if (ndims <= 1) {
|
| 20 |
+
nframe = 1;
|
| 21 |
+
dim = ndims == 0 ? 1 : input.size(0);
|
| 22 |
+
TORCH_CHECK(
|
| 23 |
+
target.dim() <= 1 && target.numel() == dim,
|
| 24 |
+
"inconsistent target size: ", target.sizes(), " for input of size: ",
|
| 25 |
+
input.sizes());
|
| 26 |
+
} else {
|
| 27 |
+
nframe = input.size(0);
|
| 28 |
+
dim = input.size(1);
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
target.dim() == 2 && target.size(0) == nframe &&
|
| 31 |
+
target.size(1) == dim,
|
| 32 |
+
"inconsistent target size: ", target.sizes(), " for input of size: ",
|
| 33 |
+
input.sizes());
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
inline void multi_margin_loss_shape_check(
|
| 38 |
+
int64_t& nframe,
|
| 39 |
+
int64_t& dim,
|
| 40 |
+
const int64_t& ndims,
|
| 41 |
+
const Tensor& input,
|
| 42 |
+
const Tensor& target,
|
| 43 |
+
const std::optional<Tensor>& weight) {
|
| 44 |
+
TORCH_CHECK(
|
| 45 |
+
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
|
| 46 |
+
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
|
| 47 |
+
input.sizes());
|
| 48 |
+
|
| 49 |
+
if (ndims <= 1) {
|
| 50 |
+
nframe = 1;
|
| 51 |
+
dim = ndims == 0 ? 1 : input.size(0);
|
| 52 |
+
} else {
|
| 53 |
+
nframe = input.size(0);
|
| 54 |
+
dim = input.size(1);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
TORCH_CHECK(
|
| 58 |
+
target.dim() <= 1 && target.numel() == nframe,
|
| 59 |
+
"inconsistent target size, expected ", nframe, " but got ",
|
| 60 |
+
target.sizes());
|
| 61 |
+
if (weight && weight->defined()) {
|
| 62 |
+
TORCH_CHECK(
|
| 63 |
+
weight->dim() <= 1 && weight->numel() == dim,
|
| 64 |
+
"inconsistent weight size, expected ", dim, " but got ",
|
| 65 |
+
weight->sizes());
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

namespace at {
// views and their in-place version ops
//
// Registers a dispatcher fallthrough for every view op (and in-place view op)
// listed below. NOTE(review): given the file name, these lists are presumably
// consumed when registering the math-bit (e.g. conj/neg) fallback keys, so
// that ops which only alias or rearrange memory skip the fallback kernel —
// confirm at the registration sites that include this header.
#define TORCH_VIEW_FNS(m) \
  m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
  m.impl("detach", torch::CppFunction::makeFallthrough()); \
  m.impl("detach_", torch::CppFunction::makeFallthrough()); \
  m.impl("diagonal", torch::CppFunction::makeFallthrough()); \
  m.impl("expand", torch::CppFunction::makeFallthrough()); \
  m.impl("expand_as", torch::CppFunction::makeFallthrough()); \
  m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \
  m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \
  m.impl("narrow", torch::CppFunction::makeFallthrough()); \
  m.impl("permute", torch::CppFunction::makeFallthrough()); \
  m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("select.int", torch::CppFunction::makeFallthrough()); \
  m.impl("squeeze", torch::CppFunction::makeFallthrough()); \
  m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose_", torch::CppFunction::makeFallthrough()); \
  m.impl("t", torch::CppFunction::makeFallthrough()); \
  m.impl("t_", torch::CppFunction::makeFallthrough()); \
  m.impl("real", torch::CppFunction::makeFallthrough()); \
  m.impl("imag", torch::CppFunction::makeFallthrough()); \
  m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \
  m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \
  m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("unfold", torch::CppFunction::makeFallthrough()); \
  m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
  m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
  m.impl("view_as", torch::CppFunction::makeFallthrough()); \
  m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
  m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \
  m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \
  m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \
  m.impl("swapdims", torch::CppFunction::makeFallthrough()); \
  m.impl("chunk", torch::CppFunction::makeFallthrough()); \
  m.impl("reshape", torch::CppFunction::makeFallthrough()); \
  m.impl("alias", torch::CppFunction::makeFallthrough()); \
  m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("conj", torch::CppFunction::makeFallthrough()); \
  m.impl("_conj", torch::CppFunction::makeFallthrough()); \
  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \
  m.impl("resize_", torch::CppFunction::makeFallthrough());

// Metadata queries and tensor constructors that never read existing element
// data, so they can also fall through the fallback key.
#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
  m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
  m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("full_like", torch::CppFunction::makeFallthrough()); \
  m.impl("stride.int", torch::CppFunction::makeFallthrough()); \
  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("size.int", torch::CppFunction::makeFallthrough()); \
  m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("is_complex", torch::CppFunction::makeFallthrough()); \
  m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
}

// Kept separate from TORCH_VIEW_FNS: as_strided/view registrations for
// libraries that register native functions. NOTE: defined outside namespace at.
#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("view", torch::CppFunction::makeFallthrough());
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBase.h>
|
| 2 |
+
#include <algorithm>
|
| 3 |
+
#include <vector>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
// Clamps a dimension count to at least 1, so 0-dim (scalar) tensors can be
// handled by code that iterates over dimensions.
inline int64_t ensure_nonempty_dim(int64_t dim) {
  return dim > 1 ? dim : 1;
}
|
| 10 |
+
|
| 11 |
+
inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
|
| 12 |
+
return t.dim() == 0 ? 1 : t.size(dim);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
|
| 16 |
+
return t.dim() == 0 ? 1 : t.stride(dim);
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
using IdxVec = std::vector<int64_t>;
// Promotes an empty size/stride list to {1}, matching the scalar-as-[1]
// convention of the helpers above. Non-empty input is returned unchanged.
inline IdxVec ensure_nonempty_vec(IdxVec vec) {
  if (!vec.empty()) {
    return vec;
  }
  return IdxVec{1};
}
|
| 26 |
+
|
| 27 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

// Kernel signature for computing renorm's per-slice scale factors over `iter`,
// given the target maximum norm `maxnorm`; per-device implementations are
// registered through the dispatch stub below.
using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);

// Backend chosen to execute batch normalization.
enum class BatchNormBackend {
  Native,  // built-in ATen implementation
  Cudnn,   // cuDNN-accelerated path
  Miopen,  // MIOpen-accelerated path
};

// Selects which BatchNormBackend to use for the given tensors, training flag,
// and epsilon. NOTE(review): the selection criteria live in the definition,
// which is not visible in this header — consult it before relying on specifics.
TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);

} // namespace at::native
|