Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h +173 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h +226 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h +46 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h +13 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h +34 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h +263 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h +449 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h +229 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/DispatchStub.h +444 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h +394 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h +518 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h +153 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h +396 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdagrad.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdam.h +27 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedSGD.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSamplerUtils.h +105 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h +16 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h +41 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h +160 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Lerp.h +46 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h +17 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h +623 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h +157 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h +97 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h +26 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h +12 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h +16 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h +40 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h +48 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h +173 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h +75 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h +128 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h +544 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Sorting.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h +88 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h +190 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h +49 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h +94 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h +55 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h +142 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h +52 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h +105 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h +57 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h +130 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h +48 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/native/AmpKernels.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/core/ATen_fwd.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
class Tensor;
|
| 8 |
+
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
|
| 12 |
+
TensorList,
|
| 13 |
+
Tensor&,
|
| 14 |
+
const Tensor&);
|
| 15 |
+
|
| 16 |
+
using _amp_update_scale_cpu__fn = Tensor& (*)(
|
| 17 |
+
Tensor&,
|
| 18 |
+
Tensor&,
|
| 19 |
+
const Tensor&,
|
| 20 |
+
double,
|
| 21 |
+
double,
|
| 22 |
+
int64_t);
|
| 23 |
+
|
| 24 |
+
DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub);
|
| 25 |
+
DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub);
|
| 26 |
+
|
| 27 |
+
} // namespace native
|
| 28 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/BucketizationUtils.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/TypeProperties.h>
|
| 5 |
+
#include <ATen/ScalarOps.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/NativeFunctions.h>
|
| 9 |
+
#else
|
| 10 |
+
#include <ATen/ops/result_type.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
// original values given by raw_*. If an original value is not contiguous, will make a contiguous copy to
|
| 16 |
+
// the corresponding trimmed_* value. Additionally, if the dtypes of the boundary and input tensor do not
|
| 17 |
+
// match, will change them to be a common super type so comparisons are done between the same types.
|
| 18 |
+
// For any trimmed_* tensor, if its outgoing value matches what it was incoming (typically null), then the
|
| 19 |
+
// corresponding raw_* version should be used since it was already contiguous of the right type.
|
| 20 |
+
inline void searchsorted_maybe_trim_input_tensors(
|
| 21 |
+
Tensor& trimmed_input,
|
| 22 |
+
Tensor& trimmed_boundaries,
|
| 23 |
+
Tensor& trimmed_sorter,
|
| 24 |
+
const Tensor& raw_input,
|
| 25 |
+
const Tensor& raw_boundaries,
|
| 26 |
+
const Tensor& raw_sorter) {
|
| 27 |
+
bool in_is_contiguous = raw_input.is_contiguous();
|
| 28 |
+
bool bd_is_contiguous = raw_boundaries.is_contiguous();
|
| 29 |
+
bool sort_is_contiguous = raw_sorter.is_contiguous();
|
| 30 |
+
|
| 31 |
+
if (!in_is_contiguous) {
|
| 32 |
+
TORCH_WARN_ONCE("torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
|
| 33 |
+
"to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
|
| 34 |
+
"tensor if possible. This message will only appear once per program.");
|
| 35 |
+
trimmed_input = raw_input.contiguous();
|
| 36 |
+
}
|
| 37 |
+
if (!bd_is_contiguous) {
|
| 38 |
+
TORCH_WARN_ONCE("torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
|
| 39 |
+
"to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
|
| 40 |
+
"tensor if possible. This message will only appear once per program.");
|
| 41 |
+
trimmed_boundaries = raw_boundaries.contiguous();
|
| 42 |
+
}
|
| 43 |
+
if (!sort_is_contiguous) {
|
| 44 |
+
TORCH_WARN_ONCE("torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
|
| 45 |
+
"to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
|
| 46 |
+
"tensor if possible. This message will only appear once per program.");
|
| 47 |
+
trimmed_sorter = raw_sorter.contiguous();
|
| 48 |
+
}
|
| 49 |
+
if (raw_input.dtype() != raw_boundaries.dtype()) {
|
| 50 |
+
at::native::ResultTypeState state = {};
|
| 51 |
+
state = at::native::update_result_type_state(raw_boundaries, state);
|
| 52 |
+
state = at::native::update_result_type_state(raw_input, state);
|
| 53 |
+
ScalarType common_stype = at::native::result_type(state);
|
| 54 |
+
|
| 55 |
+
TORCH_INTERNAL_ASSERT(common_stype != ScalarType::Undefined);
|
| 56 |
+
if (common_stype != raw_input.scalar_type()) {
|
| 57 |
+
trimmed_input = in_is_contiguous ? raw_input.to(common_stype) : trimmed_input.to(common_stype);
|
| 58 |
+
}
|
| 59 |
+
if (common_stype != raw_boundaries.scalar_type()) {
|
| 60 |
+
trimmed_boundaries = bd_is_contiguous ? raw_boundaries.to(common_stype) : trimmed_boundaries.to(common_stype);
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
/* unused but needed for internal jagged tensor class */
|
| 66 |
+
inline void searchsorted_maybe_trim_input_tensors(
|
| 67 |
+
Tensor& trimmed_input,
|
| 68 |
+
Tensor& trimmed_boundaries,
|
| 69 |
+
const Tensor& raw_input,
|
| 70 |
+
const Tensor& raw_boundaries) {
|
| 71 |
+
Tensor trimmed_sorter;
|
| 72 |
+
Tensor raw_sorter;
|
| 73 |
+
return searchsorted_maybe_trim_input_tensors(
|
| 74 |
+
trimmed_input,
|
| 75 |
+
trimmed_boundaries,
|
| 76 |
+
trimmed_sorter,
|
| 77 |
+
raw_input,
|
| 78 |
+
raw_boundaries,
|
| 79 |
+
raw_sorter);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
inline bool searchsorted_dims_matched_before_last_dim(const Tensor& boundaries, const Tensor& input) {
|
| 83 |
+
if (boundaries.dim() != input.dim()) {
|
| 84 |
+
return false;
|
| 85 |
+
}
|
| 86 |
+
const auto& dims_bd = boundaries.sizes();
|
| 87 |
+
const auto& dims_in = input.sizes();
|
| 88 |
+
for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
|
| 89 |
+
if (dims_bd[dim] != dims_in[dim]) {
|
| 90 |
+
return false;
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
return true;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
inline Tensor searchsorted_scalar_tensor(const Scalar& scalar, const c10::Device& device) {
|
| 97 |
+
auto tensor = c10::scalar_to_tensor(scalar, device);
|
| 98 |
+
// This is to adopt the scalar promotion rules defined in native/TypeProperties.h
|
| 99 |
+
// So we have the same type promotion rules as binary operations.
|
| 100 |
+
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
|
| 101 |
+
return tensor;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
inline void searchsorted_pre_check(
|
| 105 |
+
const Tensor& boundaries,
|
| 106 |
+
const Tensor& input,
|
| 107 |
+
const Tensor& output,
|
| 108 |
+
const bool out_int32,
|
| 109 |
+
const bool right,
|
| 110 |
+
const std::optional<c10::string_view> side_opt,
|
| 111 |
+
const Tensor& sorter) {
|
| 112 |
+
if (side_opt) {
|
| 113 |
+
const c10::string_view side = *side_opt;
|
| 114 |
+
TORCH_CHECK(side == "left" || side == "right", "torch.searchsorted(): side can only be 'left' or 'right' but ",
|
| 115 |
+
"got ", side);
|
| 116 |
+
|
| 117 |
+
// assume the user has not explicitly set (right=False, side="right")
|
| 118 |
+
TORCH_CHECK(!right || side == "right", "torch.searchsorted(): side and right can't be set to opposites, got side "
|
| 119 |
+
"of ", side, " while right was True");
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
TORCH_CHECK(boundaries.device() == input.device(), "torch.searchsorted(): boundaries and input value tensors ",
|
| 123 |
+
"should have same device type, but got boundaries tensor device type ", boundaries.device(), " and input value ",
|
| 124 |
+
"tensor device type ", input.device());
|
| 125 |
+
|
| 126 |
+
if (sorter.defined()) {
|
| 127 |
+
TORCH_CHECK(sorter.device() == boundaries.device(), "torch.searchsorted(): sorter and boundary tensors should ",
|
| 128 |
+
"have same device type, but got sorter tensor device type ", sorter.device(), " and input value tensor ",
|
| 129 |
+
"device type ", boundaries.device());
|
| 130 |
+
|
| 131 |
+
TORCH_CHECK(sorter.sizes() == boundaries.sizes(), "torch.searchsorted(): boundary and sorter must have the same "
|
| 132 |
+
"size, but got boundary tensor ", boundaries.sizes(), "and got sorter tensor ", sorter.sizes());
|
| 133 |
+
|
| 134 |
+
TORCH_CHECK(sorter.scalar_type() == ScalarType::Long, "torch.searchsorted(): sorter must be a tensor of long ",
|
| 135 |
+
"dtype but got dtype ", sorter.scalar_type());
|
| 136 |
+
|
| 137 |
+
if (sorter.numel() > 0) {
|
| 138 |
+
auto minmax = sorter.aminmax();
|
| 139 |
+
int64_t vmin = std::get<0>(minmax).item().toLong();
|
| 140 |
+
int64_t vmax = std::get<1>(minmax).item().toLong();
|
| 141 |
+
TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
|
| 146 |
+
"torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we got ",
|
| 147 |
+
"boundaries tensor dim(", boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(",
|
| 148 |
+
input.numel(), ")");
|
| 149 |
+
|
| 150 |
+
TORCH_CHECK(boundaries.dim() != 0, "torch.searchsorted(): boundaries tensor should have positive dimension, but ",
|
| 151 |
+
"got 0 dimension");
|
| 152 |
+
|
| 153 |
+
TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
|
| 154 |
+
"torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor ",
|
| 155 |
+
"and input value tensor must match, but we got boundaries tensor ", boundaries.sizes(), " and input value tensor ",
|
| 156 |
+
input.sizes());
|
| 157 |
+
|
| 158 |
+
ScalarType output_dtype = output.scalar_type();
|
| 159 |
+
TORCH_CHECK(
|
| 160 |
+
(output_dtype == ScalarType::Long && !out_int32) ||
|
| 161 |
+
(output_dtype == ScalarType::Int && out_int32),
|
| 162 |
+
"torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) depending on ",
|
| 163 |
+
"whether out_int32 flag is True, but we got output tensor's dtype ", output_dtype,
|
| 164 |
+
" and out_int32 flag is ", (out_int32 ? "True" : "False"));
|
| 165 |
+
|
| 166 |
+
if (out_int32) {
|
| 167 |
+
TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
|
| 168 |
+
"torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX, ", but we got ",
|
| 169 |
+
boundaries.sizes().back());
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUBlas.h
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/OpMathType.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <ATen/native/TransposeType.h>
|
| 6 |
+
#include <c10/util/complex.h>
|
| 7 |
+
#include <c10/core/ScalarType.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
namespace at::native::cpublas {
|
| 12 |
+
|
| 13 |
+
namespace internal {
|
| 14 |
+
void normalize_last_dims(
|
| 15 |
+
TransposeType transa, TransposeType transb,
|
| 16 |
+
int64_t m, int64_t n, int64_t k,
|
| 17 |
+
int64_t *lda, int64_t *ldb, int64_t *ldc);
|
| 18 |
+
} // namespace internal
|
| 19 |
+
|
| 20 |
+
using gemm_fn = void(*)(
|
| 21 |
+
at::ScalarType type,
|
| 22 |
+
TransposeType transa, TransposeType transb,
|
| 23 |
+
int64_t m, int64_t n, int64_t k,
|
| 24 |
+
const Scalar& alpha,
|
| 25 |
+
const void *a, int64_t lda,
|
| 26 |
+
const void *b, int64_t ldb,
|
| 27 |
+
const Scalar& beta,
|
| 28 |
+
void *c, int64_t ldc);
|
| 29 |
+
|
| 30 |
+
DECLARE_DISPATCH(gemm_fn, gemm_stub);
|
| 31 |
+
|
| 32 |
+
template <typename scalar_t>
|
| 33 |
+
void gemm(
|
| 34 |
+
TransposeType transa, TransposeType transb,
|
| 35 |
+
int64_t m, int64_t n, int64_t k,
|
| 36 |
+
at::opmath_type<scalar_t> alpha,
|
| 37 |
+
const scalar_t *a, int64_t lda,
|
| 38 |
+
const scalar_t *b, int64_t ldb,
|
| 39 |
+
at::opmath_type<scalar_t> beta,
|
| 40 |
+
scalar_t *c, int64_t ldc) {
|
| 41 |
+
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
| 42 |
+
gemm_stub(
|
| 43 |
+
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
| 44 |
+
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
void gemm(
|
| 48 |
+
TransposeType transa, TransposeType transb,
|
| 49 |
+
int64_t m, int64_t n, int64_t k,
|
| 50 |
+
double alpha,
|
| 51 |
+
const double *a, int64_t lda,
|
| 52 |
+
const double *b, int64_t ldb,
|
| 53 |
+
double beta,
|
| 54 |
+
double *c, int64_t ldc);
|
| 55 |
+
|
| 56 |
+
void gemm(
|
| 57 |
+
TransposeType transa, TransposeType transb,
|
| 58 |
+
int64_t m, int64_t n, int64_t k,
|
| 59 |
+
float alpha,
|
| 60 |
+
const float *a, int64_t lda,
|
| 61 |
+
const float *b, int64_t ldb,
|
| 62 |
+
float beta,
|
| 63 |
+
float *c, int64_t ldc);
|
| 64 |
+
|
| 65 |
+
void gemm(
|
| 66 |
+
TransposeType transa, TransposeType transb,
|
| 67 |
+
int64_t m, int64_t n, int64_t k,
|
| 68 |
+
float alpha,
|
| 69 |
+
const at::BFloat16 *a, int64_t lda,
|
| 70 |
+
const at::BFloat16 *b, int64_t ldb,
|
| 71 |
+
float beta,
|
| 72 |
+
at::BFloat16 *c, int64_t ldc);
|
| 73 |
+
|
| 74 |
+
void gemm(
|
| 75 |
+
TransposeType transa, TransposeType transb,
|
| 76 |
+
int64_t m, int64_t n, int64_t k,
|
| 77 |
+
const float alpha,
|
| 78 |
+
const at::BFloat16 *a, int64_t lda,
|
| 79 |
+
const at::BFloat16 *b, int64_t ldb,
|
| 80 |
+
const float beta,
|
| 81 |
+
float *c, int64_t ldc);
|
| 82 |
+
|
| 83 |
+
void gemm(
|
| 84 |
+
TransposeType transa, TransposeType transb,
|
| 85 |
+
int64_t m, int64_t n, int64_t k,
|
| 86 |
+
float alpha,
|
| 87 |
+
const at::Half *a, int64_t lda,
|
| 88 |
+
const at::Half *b, int64_t ldb,
|
| 89 |
+
float beta,
|
| 90 |
+
at::Half *c, int64_t ldc);
|
| 91 |
+
|
| 92 |
+
void gemm(
|
| 93 |
+
TransposeType transa, TransposeType transb,
|
| 94 |
+
int64_t m, int64_t n, int64_t k,
|
| 95 |
+
const float alpha,
|
| 96 |
+
const at::Half *a, int64_t lda,
|
| 97 |
+
const at::Half *b, int64_t ldb,
|
| 98 |
+
const float beta,
|
| 99 |
+
float *c, int64_t ldc);
|
| 100 |
+
|
| 101 |
+
void gemm(
|
| 102 |
+
TransposeType transa, TransposeType transb,
|
| 103 |
+
int64_t m, int64_t n, int64_t k,
|
| 104 |
+
c10::complex<double> alpha,
|
| 105 |
+
const c10::complex<double> *a, int64_t lda,
|
| 106 |
+
const c10::complex<double> *b, int64_t ldb,
|
| 107 |
+
c10::complex<double> beta,
|
| 108 |
+
c10::complex<double> *c, int64_t ldc);
|
| 109 |
+
|
| 110 |
+
void gemm(
|
| 111 |
+
TransposeType transa, TransposeType transb,
|
| 112 |
+
int64_t m, int64_t n, int64_t k,
|
| 113 |
+
c10::complex<float> alpha,
|
| 114 |
+
const c10::complex<float> *a, int64_t lda,
|
| 115 |
+
const c10::complex<float> *b, int64_t ldb,
|
| 116 |
+
c10::complex<float> beta,
|
| 117 |
+
c10::complex<float> *c, int64_t ldc);
|
| 118 |
+
|
| 119 |
+
void gemm(
|
| 120 |
+
TransposeType transa, TransposeType transb,
|
| 121 |
+
int64_t m, int64_t n, int64_t k,
|
| 122 |
+
int64_t alpha,
|
| 123 |
+
const int64_t *a, int64_t lda,
|
| 124 |
+
const int64_t *b, int64_t ldb,
|
| 125 |
+
int64_t beta,
|
| 126 |
+
int64_t *c, int64_t ldc);
|
| 127 |
+
|
| 128 |
+
template <typename scalar_t>
|
| 129 |
+
void gemm_batched(
|
| 130 |
+
TransposeType transa, TransposeType transb,
|
| 131 |
+
int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
| 132 |
+
scalar_t alpha,
|
| 133 |
+
const scalar_t * const *a, int64_t lda,
|
| 134 |
+
const scalar_t * const *b, int64_t ldb,
|
| 135 |
+
const scalar_t beta,
|
| 136 |
+
scalar_t * const *c, int64_t ldc);
|
| 137 |
+
|
| 138 |
+
template <typename scalar_t>
|
| 139 |
+
void gemm_batched_with_stride(
|
| 140 |
+
TransposeType transa, TransposeType transb,
|
| 141 |
+
int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
| 142 |
+
scalar_t alpha,
|
| 143 |
+
const scalar_t *a, int64_t lda, int64_t batch_stride_a,
|
| 144 |
+
const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
|
| 145 |
+
scalar_t beta,
|
| 146 |
+
scalar_t *c, int64_t ldc, int64_t batch_stride_c);
|
| 147 |
+
|
| 148 |
+
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
|
| 149 |
+
|
| 150 |
+
DECLARE_DISPATCH(axpy_fn, axpy_stub);
|
| 151 |
+
|
| 152 |
+
template<typename scalar_t>
|
| 153 |
+
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
|
| 154 |
+
if(n == 1)
|
| 155 |
+
{
|
| 156 |
+
incx = 1;
|
| 157 |
+
incy = 1;
|
| 158 |
+
}
|
| 159 |
+
axpy_stub(
|
| 160 |
+
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
| 161 |
+
n, a, x, incx, y, incy);
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
|
| 165 |
+
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
|
| 166 |
+
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
|
| 167 |
+
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
|
| 168 |
+
|
| 169 |
+
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
|
| 170 |
+
|
| 171 |
+
DECLARE_DISPATCH(copy_fn, copy_stub);
|
| 172 |
+
|
| 173 |
+
template<typename scalar_t>
|
| 174 |
+
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
|
| 175 |
+
if(n == 1)
|
| 176 |
+
{
|
| 177 |
+
incx = 1;
|
| 178 |
+
incy = 1;
|
| 179 |
+
}
|
| 180 |
+
copy_stub(
|
| 181 |
+
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
| 182 |
+
n, x, incx, y, incy);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
|
| 186 |
+
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
|
| 187 |
+
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
|
| 188 |
+
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
|
| 189 |
+
|
| 190 |
+
// Batch-reduce GEMM
|
| 191 |
+
// Operates by the following formula:
|
| 192 |
+
// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size
|
| 193 |
+
// A Base pointer to a tensor A.
|
| 194 |
+
// B Base pointer to a tensor B.
|
| 195 |
+
// C Pointer to a tensor C (accumulation buffer).
|
| 196 |
+
TORCH_API void brgemm(
|
| 197 |
+
int64_t M,
|
| 198 |
+
int64_t N,
|
| 199 |
+
int64_t K,
|
| 200 |
+
int64_t ld_a,
|
| 201 |
+
int64_t ld_b,
|
| 202 |
+
int64_t ld_c,
|
| 203 |
+
const float alpha,
|
| 204 |
+
const float beta,
|
| 205 |
+
const at::Half* A,
|
| 206 |
+
const at::Half* B,
|
| 207 |
+
float* C);
|
| 208 |
+
|
| 209 |
+
// Release brgemm hardware context
|
| 210 |
+
void brgemm_release();
|
| 211 |
+
|
| 212 |
+
// Pack B matrix to get better performance if needed
|
| 213 |
+
void pack(
|
| 214 |
+
int64_t K,
|
| 215 |
+
int64_t N,
|
| 216 |
+
int64_t ld_in,
|
| 217 |
+
int64_t ld_out,
|
| 218 |
+
ScalarType dt_in,
|
| 219 |
+
ScalarType dt_out,
|
| 220 |
+
const void* in,
|
| 221 |
+
void* out);
|
| 222 |
+
|
| 223 |
+
// Whether pack is needed in the platform.
|
| 224 |
+
bool need_pack(ScalarType dt_in);
|
| 225 |
+
|
| 226 |
+
} // namespace at::native::cpublas
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CPUFallback.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/ivalue.h>
|
| 4 |
+
#include <ATen/core/stack.h>
|
| 5 |
+
#include <ATen/core/boxing/KernelFunction.h>
|
| 6 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 7 |
+
#include <c10/util/Metaprogramming.h>
|
| 8 |
+
#include <torch/library.h>
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
// This function implements a boxed fallback to CPU.
|
| 13 |
+
// External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
|
| 14 |
+
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false,
|
| 15 |
+
c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU);
|
| 16 |
+
|
| 17 |
+
// This is a helper function that backends can use to directly call their boxed CPU fallback
|
| 18 |
+
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
|
| 19 |
+
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 20 |
+
struct _call_fallback_fn final {};
|
| 21 |
+
|
| 22 |
+
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 23 |
+
struct _call_fallback_fn<fallback_fn, Op, symint, ReturnType(ParameterTypes...)> final {
|
| 24 |
+
static ReturnType call(typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
|
| 25 |
+
auto op = c10::Dispatcher::singleton()
|
| 26 |
+
// TODO: figure out how to make compiler happy without dynamic casts
|
| 27 |
+
.findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
|
| 28 |
+
//.findSchemaOrThrow("a", "b")
|
| 29 |
+
.typed<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>();
|
| 30 |
+
return c10::impl::BoxedKernelWrapper<ReturnType (typename c10::maybe_keep_symint<symint, ParameterTypes>::type...)>::call(
|
| 31 |
+
c10::BoxedKernel::makeFromFunction<fallback_fn>(),
|
| 32 |
+
op,
|
| 33 |
+
c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
|
| 34 |
+
// TODO: get std::forward<> to work
|
| 35 |
+
args...
|
| 36 |
+
);
|
| 37 |
+
}
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
|
| 41 |
+
using call_fallback_fn_symint = _call_fallback_fn<fallback_fn, Op, true, typename Op::schema>;
|
| 42 |
+
|
| 43 |
+
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
|
| 44 |
+
using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, false, typename Op::schema>;
|
| 45 |
+
|
| 46 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/macros/Export.h>
|
| 3 |
+
#include <limits>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class TensorBase;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits<int32_t>::max());
|
| 12 |
+
|
| 13 |
+
}
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/CompositeRandomAccessorCommon.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
struct TupleInfoCPU {
|
| 8 |
+
template <typename ...Types>
|
| 9 |
+
using tuple = std::tuple<Types...>;
|
| 10 |
+
|
| 11 |
+
template <typename ...Types>
|
| 12 |
+
static constexpr auto tie(Types&... args) noexcept {
|
| 13 |
+
return std::tie(args...);
|
| 14 |
+
}
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
template <typename KeyAccessor, typename ValueAccessor>
|
| 18 |
+
using CompositeRandomAccessorCPU =
|
| 19 |
+
CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
|
| 20 |
+
|
| 21 |
+
template <typename Values, typename References>
|
| 22 |
+
void swap(
|
| 23 |
+
references_holder<Values, References> rh1,
|
| 24 |
+
references_holder<Values, References> rh2
|
| 25 |
+
) {
|
| 26 |
+
return std::swap(rh1.data(), rh2.data());
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
template <int N, typename Values, typename References>
|
| 30 |
+
auto get(references_holder<Values, References> rh) -> decltype(std::get<N>(rh.data())) {
|
| 31 |
+
return std::get<N>(rh.data());
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessorCommon.h
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <utility>
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
namespace {
|
| 8 |
+
|
| 9 |
+
// operator_brackets_proxy is used in
|
| 10 |
+
// CompositeRandomAccessor in place of operator[].
|
| 11 |
+
// For some iterators, references returned by operator[]
|
| 12 |
+
// could become invalid, operator_brackets_proxy tries to
|
| 13 |
+
// resolve that by making accessor[n] to be equivalent to
|
| 14 |
+
// *(accessor + n).
|
| 15 |
+
template <typename Accessor>
|
| 16 |
+
class operator_brackets_proxy {
|
| 17 |
+
using reference = typename std::iterator_traits<Accessor>::reference;
|
| 18 |
+
using value_type = typename std::iterator_traits<Accessor>::value_type;
|
| 19 |
+
|
| 20 |
+
public:
|
| 21 |
+
C10_HOST_DEVICE
|
| 22 |
+
operator_brackets_proxy(Accessor const& accessor)
|
| 23 |
+
: accessor(accessor)
|
| 24 |
+
{}
|
| 25 |
+
|
| 26 |
+
C10_HOST_DEVICE
|
| 27 |
+
operator reference() {
|
| 28 |
+
return *accessor;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
C10_HOST_DEVICE
|
| 32 |
+
reference operator*() {
|
| 33 |
+
return *accessor;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
C10_HOST_DEVICE
|
| 37 |
+
operator_brackets_proxy& operator=(value_type const& val) {
|
| 38 |
+
*accessor = val;
|
| 39 |
+
return *this;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
private:
|
| 43 |
+
Accessor accessor;
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// references_holder is used as a surrogate for the
|
| 49 |
+
// references type from std::iterator_traits in CompositeRandomAccessor.
|
| 50 |
+
// It is assumed in CompositeRandomAccessor that
|
| 51 |
+
// References = tuple<Types&...>,
|
| 52 |
+
// Values = tuple<Types...> by default,
|
| 53 |
+
// but they could be anything as long as References could be
|
| 54 |
+
// cast to Values.
|
| 55 |
+
// If you plan to use it with STL, for example, you will need to
|
| 56 |
+
// define 'swap` and `get`(aka std::get) methods.
|
| 57 |
+
template <typename Values, typename References>
|
| 58 |
+
class references_holder {
|
| 59 |
+
public:
|
| 60 |
+
using values = Values;
|
| 61 |
+
using references = References;
|
| 62 |
+
|
| 63 |
+
C10_HOST_DEVICE
|
| 64 |
+
references_holder(references refs)
|
| 65 |
+
: refs{std::move(refs)}
|
| 66 |
+
{}
|
| 67 |
+
|
| 68 |
+
C10_HOST_DEVICE
|
| 69 |
+
operator references() {
|
| 70 |
+
return refs;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
C10_HOST_DEVICE
|
| 74 |
+
operator values() {
|
| 75 |
+
return refs;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
C10_HOST_DEVICE
|
| 79 |
+
references_holder& operator=(values vals) {
|
| 80 |
+
refs = vals;
|
| 81 |
+
return *this;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
C10_HOST_DEVICE
|
| 85 |
+
references& data() {
|
| 86 |
+
return refs;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
protected:
|
| 90 |
+
references refs;
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
// CompositeRandomAccessor is essentially a simplified version of
|
| 94 |
+
// a random access iterator over two random access iterators.
|
| 95 |
+
// TupleInfo should contain a variadic type `tuple`, and a method `tie`,
|
| 96 |
+
// which constructs a tuple of references from a variadic list of arguments.
|
| 97 |
+
template <typename KeyAccessor, typename ValueAccessor, typename TupleInfo>
|
| 98 |
+
class CompositeRandomAccessor {
|
| 99 |
+
using self_type = CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfo>;
|
| 100 |
+
|
| 101 |
+
using key_accessor_value_type =
|
| 102 |
+
typename std::iterator_traits<KeyAccessor>::value_type;
|
| 103 |
+
using value_accessor_value_type =
|
| 104 |
+
typename std::iterator_traits<ValueAccessor>::value_type;
|
| 105 |
+
using key_accessor_reference_type =
|
| 106 |
+
typename std::iterator_traits<KeyAccessor>::reference;
|
| 107 |
+
using value_accessor_reference_type =
|
| 108 |
+
typename std::iterator_traits<ValueAccessor>::reference;
|
| 109 |
+
|
| 110 |
+
using composite_value_type = typename TupleInfo::template tuple<
|
| 111 |
+
key_accessor_value_type,
|
| 112 |
+
value_accessor_value_type>;
|
| 113 |
+
using composite_reference = typename TupleInfo::template tuple<
|
| 114 |
+
key_accessor_reference_type,
|
| 115 |
+
value_accessor_reference_type>;
|
| 116 |
+
|
| 117 |
+
public:
|
| 118 |
+
using value_type = composite_value_type;
|
| 119 |
+
using reference = references_holder<composite_value_type, composite_reference>;
|
| 120 |
+
// Note that CompositeRandomAccessor does not hold key and values
|
| 121 |
+
// in a specific datastructure, which means that a pointer to a (key, value)
|
| 122 |
+
// is not defined. Hence we just use a pointer type of the KeyAccessor.
|
| 123 |
+
using pointer = typename std::iterator_traits<KeyAccessor>::pointer;
|
| 124 |
+
using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
|
| 125 |
+
using iterator_category = std::random_access_iterator_tag;
|
| 126 |
+
|
| 127 |
+
C10_HOST_DEVICE
|
| 128 |
+
CompositeRandomAccessor() = default;
|
| 129 |
+
|
| 130 |
+
C10_HOST_DEVICE
|
| 131 |
+
CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
|
| 132 |
+
: keys(keys), values(values)
|
| 133 |
+
{}
|
| 134 |
+
|
| 135 |
+
// Pointer-like operations {
|
| 136 |
+
C10_HOST_DEVICE
|
| 137 |
+
reference operator*() const {
|
| 138 |
+
return TupleInfo::tie(*keys, *values);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// operator->() is supposed to return a pointer type.
|
| 142 |
+
// Since CompositeRandomAccessor does not hold pointers to pairs,
|
| 143 |
+
// we just return a pointer to a key.
|
| 144 |
+
C10_HOST_DEVICE
|
| 145 |
+
auto* operator->() const {
|
| 146 |
+
return keys.operator->();
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
C10_HOST_DEVICE
|
| 150 |
+
reference operator[](difference_type idx) {
|
| 151 |
+
return operator_brackets_proxy<self_type>(
|
| 152 |
+
CompositeRandomAccessor(keys + idx, values + idx)
|
| 153 |
+
);
|
| 154 |
+
}
|
| 155 |
+
// }
|
| 156 |
+
|
| 157 |
+
// Prefix/postfix increment/decrement {
|
| 158 |
+
C10_HOST_DEVICE
|
| 159 |
+
CompositeRandomAccessor& operator++() {
|
| 160 |
+
++keys;
|
| 161 |
+
++values;
|
| 162 |
+
return *this;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
C10_HOST_DEVICE
|
| 166 |
+
CompositeRandomAccessor operator++(int) {
|
| 167 |
+
CompositeRandomAccessor copy(*this);
|
| 168 |
+
++*this;
|
| 169 |
+
return copy;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
C10_HOST_DEVICE
|
| 173 |
+
CompositeRandomAccessor& operator--() {
|
| 174 |
+
--keys;
|
| 175 |
+
--values;
|
| 176 |
+
return *this;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
C10_HOST_DEVICE
|
| 180 |
+
CompositeRandomAccessor operator--(int) {
|
| 181 |
+
CompositeRandomAccessor copy(*this);
|
| 182 |
+
--*this;
|
| 183 |
+
return copy;
|
| 184 |
+
}
|
| 185 |
+
// }
|
| 186 |
+
|
| 187 |
+
// Arithmetic operations {
|
| 188 |
+
C10_HOST_DEVICE
|
| 189 |
+
CompositeRandomAccessor& operator+=(difference_type offset) {
|
| 190 |
+
keys += offset;
|
| 191 |
+
values += offset;
|
| 192 |
+
return *this;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
C10_HOST_DEVICE
|
| 196 |
+
CompositeRandomAccessor operator+(difference_type offset) const {
|
| 197 |
+
return CompositeRandomAccessor(keys + offset, values + offset);
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
C10_HOST_DEVICE
|
| 201 |
+
friend CompositeRandomAccessor operator+(
|
| 202 |
+
difference_type offset,
|
| 203 |
+
const CompositeRandomAccessor& accessor
|
| 204 |
+
) {
|
| 205 |
+
return accessor + offset;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
C10_HOST_DEVICE
|
| 209 |
+
CompositeRandomAccessor& operator-=(difference_type offset) {
|
| 210 |
+
keys -= offset;
|
| 211 |
+
values -= offset;
|
| 212 |
+
return *this;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
C10_HOST_DEVICE
|
| 216 |
+
CompositeRandomAccessor operator-(difference_type offset) const {
|
| 217 |
+
return CompositeRandomAccessor(keys - offset, values - offset);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
C10_HOST_DEVICE
|
| 221 |
+
difference_type operator-(const CompositeRandomAccessor& other) const {
|
| 222 |
+
return keys - other.keys;
|
| 223 |
+
}
|
| 224 |
+
// }
|
| 225 |
+
|
| 226 |
+
// Comparison operators {
|
| 227 |
+
C10_HOST_DEVICE
|
| 228 |
+
bool operator==(const CompositeRandomAccessor& other) const {
|
| 229 |
+
return keys == other.keys;
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
C10_HOST_DEVICE
|
| 233 |
+
bool operator!=(const CompositeRandomAccessor& other) const {
|
| 234 |
+
return keys != other.keys;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
C10_HOST_DEVICE
|
| 238 |
+
bool operator<(const CompositeRandomAccessor& other) const {
|
| 239 |
+
return keys < other.keys;
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
C10_HOST_DEVICE
|
| 243 |
+
bool operator<=(const CompositeRandomAccessor& other) const {
|
| 244 |
+
return keys <= other.keys;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
C10_HOST_DEVICE
|
| 248 |
+
bool operator>(const CompositeRandomAccessor& other) const {
|
| 249 |
+
return keys > other.keys;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
C10_HOST_DEVICE
|
| 253 |
+
bool operator>=(const CompositeRandomAccessor& other) const {
|
| 254 |
+
return keys >= other.keys;
|
| 255 |
+
}
|
| 256 |
+
// }
|
| 257 |
+
|
| 258 |
+
protected:
|
| 259 |
+
KeyAccessor keys;
|
| 260 |
+
ValueAccessor values;
|
| 261 |
+
};
|
| 262 |
+
|
| 263 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvUtils.h
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/TensorUtils.h>
|
| 4 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <c10/util/env.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
#include <utility>
|
| 10 |
+
|
| 11 |
+
namespace at::native {
|
| 12 |
+
|
| 13 |
+
using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
|
| 14 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 15 |
+
at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
|
| 16 |
+
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
|
| 17 |
+
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 18 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 19 |
+
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
|
| 20 |
+
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
|
| 21 |
+
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
|
| 22 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 23 |
+
at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
|
| 24 |
+
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
|
| 25 |
+
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 26 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 27 |
+
at::IntArrayRef, int64_t, std::array<bool,3>);
|
| 28 |
+
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
|
| 29 |
+
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
|
| 30 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 31 |
+
at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
|
| 32 |
+
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
|
| 33 |
+
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 34 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 35 |
+
at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
|
| 36 |
+
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
|
| 37 |
+
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 38 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 39 |
+
at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
|
| 40 |
+
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
|
| 41 |
+
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 42 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 43 |
+
at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
|
| 44 |
+
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
|
| 45 |
+
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 46 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 47 |
+
at::IntArrayRef, int64_t, std::array<bool,3>);
|
| 48 |
+
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
|
| 49 |
+
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const std::optional<Tensor>&,
|
| 50 |
+
IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
|
| 51 |
+
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
|
| 52 |
+
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 53 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 54 |
+
at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
|
| 55 |
+
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
|
| 56 |
+
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 57 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 58 |
+
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
|
| 59 |
+
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
|
| 60 |
+
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 61 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 62 |
+
at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
|
| 63 |
+
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
|
| 64 |
+
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 65 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 66 |
+
at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
|
| 67 |
+
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
|
| 68 |
+
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
|
| 69 |
+
const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
|
| 70 |
+
at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
|
| 71 |
+
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);
|
| 72 |
+
|
| 73 |
+
namespace {
|
| 74 |
+
bool is_cudnnv8_heuristic_mode_b() {
|
| 75 |
+
static const bool is_cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
|
| 76 |
+
return is_cudnnv8_heuristic_mode_b;
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
inline bool cudnnv8_enabled_check_debug() {
|
| 81 |
+
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
|
| 82 |
+
static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
|
| 83 |
+
static uint8_t cudnnv8_debugcount = 0;
|
| 84 |
+
if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
|
| 85 |
+
TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", is_cudnnv8_heuristic_mode_b());
|
| 86 |
+
cudnnv8_debugcount++;
|
| 87 |
+
}
|
| 88 |
+
return cudnnv8_flag == 1;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
inline bool cudnnv8_use_heur_mode_b() {
|
| 92 |
+
return is_cudnnv8_heuristic_mode_b();
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// Keep in sync with py::enum_ in Module.cpp
|
| 96 |
+
enum class ConvBackend {
|
| 97 |
+
CudaDepthwise2d,
|
| 98 |
+
CudaDepthwise3d,
|
| 99 |
+
Cudnn,
|
| 100 |
+
CudnnTranspose,
|
| 101 |
+
Empty,
|
| 102 |
+
Miopen,
|
| 103 |
+
MiopenDepthwise,
|
| 104 |
+
MiopenTranspose,
|
| 105 |
+
Mkldnn,
|
| 106 |
+
MkldnnTranspose,
|
| 107 |
+
MkldnnEmpty,
|
| 108 |
+
NnpackSpatial,
|
| 109 |
+
Overrideable,
|
| 110 |
+
Slow2d,
|
| 111 |
+
Slow3d,
|
| 112 |
+
SlowDilated2d,
|
| 113 |
+
SlowDilated3d,
|
| 114 |
+
SlowTranspose2d,
|
| 115 |
+
SlowTranspose3d,
|
| 116 |
+
Winograd3x3Depthwise,
|
| 117 |
+
Xnnpack2d,
|
| 118 |
+
Mps,
|
| 119 |
+
MpsTranspose,
|
| 120 |
+
};
|
| 121 |
+
|
| 122 |
+
// Overload for selecting the convolution backend from the full set of convolution inputs.
|
| 123 |
+
// This overload is exposed to python for testing, etc.
|
| 124 |
+
TORCH_API ConvBackend select_conv_backend(
|
| 125 |
+
const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
|
| 126 |
+
SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation,
|
| 127 |
+
bool transposed, SymIntArrayRef output_padding, c10::SymInt groups, const at::OptionalSymIntArrayRef bias_sizes_opt);
|
| 128 |
+
|
| 129 |
+
TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
|
| 130 |
+
const Tensor& weight,
|
| 131 |
+
const ConvBackend backend);
|
| 132 |
+
|
| 133 |
+
// ---------------------------------------------------------------------
|
| 134 |
+
//
|
| 135 |
+
// Math
|
| 136 |
+
//
|
| 137 |
+
// ---------------------------------------------------------------------
|
| 138 |
+
|
| 139 |
+
constexpr int input_batch_size_dim = 0; // also grad_input
|
| 140 |
+
constexpr int input_channels_dim = 1;
|
| 141 |
+
constexpr int output_batch_size_dim = 0; // also grad_output
|
| 142 |
+
constexpr int output_channels_dim = 1;
|
| 143 |
+
constexpr int weight_output_channels_dim = 0;
|
| 144 |
+
constexpr int weight_input_channels_dim = 1;
|
| 145 |
+
|
| 146 |
+
// Often written as 2 + max_dim (extra dims for batch size and channels)
|
| 147 |
+
constexpr int max_dim = 3;
|
| 148 |
+
|
| 149 |
+
// ---------------------------------------------------------------------
|
| 150 |
+
//
|
| 151 |
+
// Checking
|
| 152 |
+
//
|
| 153 |
+
// ---------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
// Used on pad, stride and dilation
|
| 156 |
+
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
|
| 157 |
+
{
|
| 158 |
+
TORCH_CHECK(args.size() <= expected_size,
|
| 159 |
+
"Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
|
| 160 |
+
expected_size, " (while checking arguments for ", c, ")");
|
| 161 |
+
TORCH_CHECK(args.size() >= expected_size,
|
| 162 |
+
"Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
|
| 163 |
+
expected_size, " (while checking arguments for ", c, ")");
|
| 164 |
+
|
| 165 |
+
auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
|
| 166 |
+
if (num_negative_values > 0){
|
| 167 |
+
std::stringstream ss;
|
| 168 |
+
ss << arg_name << " should be greater than zero but got (";
|
| 169 |
+
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
|
| 170 |
+
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
|
| 171 |
+
AT_ERROR(ss.str());
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
// NOTE [ Convolution checks ]
|
| 177 |
+
//
|
| 178 |
+
// NB: For many call sites, it is not strictly necessary to check all of
|
| 179 |
+
// these relationships (for example, for forward convolution, we compute
|
| 180 |
+
// the size of output ourselves, so we don't actually need to check
|
| 181 |
+
// output. However, writing a single function that does everything
|
| 182 |
+
// means we get to reuse it for both forwards and all backwards
|
| 183 |
+
// variants, even when the set of "real" inputs varies. The magic of
|
| 184 |
+
// relational computing!
|
| 185 |
+
//
|
| 186 |
+
// (There is one downside, which is that it is slightly harder to write
|
| 187 |
+
// error messages which are able to distinguish between real inputs
|
| 188 |
+
// (which the user can change) and computed inputs (which the user can
|
| 189 |
+
// only indirectly affect). It would be an interesting exercise to
|
| 190 |
+
// come up with a general framework to handle such situations.)
|
| 191 |
+
inline void convolution_shape_check(
|
| 192 |
+
CheckedFrom c,
|
| 193 |
+
const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
|
| 194 |
+
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
|
| 195 |
+
{
|
| 196 |
+
check_args(c, padding, input->dim() - 2, "padding");
|
| 197 |
+
check_args(c, stride, padding.size(), "stride");
|
| 198 |
+
check_args(c, dilation, padding.size(), "dilation");
|
| 199 |
+
|
| 200 |
+
// Input
|
| 201 |
+
checkDimRange(c, input, 3, 6 /* exclusive */);
|
| 202 |
+
checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);
|
| 203 |
+
|
| 204 |
+
// Weight
|
| 205 |
+
checkSameDim(c, input, weight);
|
| 206 |
+
|
| 207 |
+
// TODO: check that output->size() matches output_sizes
|
| 208 |
+
// TODO: check that weight matches output->sizes()
|
| 209 |
+
checkSameDim(c, input, output);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// NB: conv_output_size and conv_input_size are not bijections,
|
| 213 |
+
// as conv_output_size loses information; this is why conv_input_size
|
| 214 |
+
// takes an extra output_padding argument to resolve the ambiguity.
|
| 215 |
+
|
| 216 |
+
template <typename T>
|
| 217 |
+
inline std::vector<T> _conv_output_size(
|
| 218 |
+
ArrayRef<T> input_size, ArrayRef<T> weight_size,
|
| 219 |
+
ArrayRef<T> padding, ArrayRef<T> stride, ArrayRef<T> dilation = ArrayRef<T>()
|
| 220 |
+
) {
|
| 221 |
+
// ASSERT(input_size.size() > 2)
|
| 222 |
+
// ASSERT(input_size.size() == weight_size.size())
|
| 223 |
+
bool has_dilation = !dilation.empty();
|
| 224 |
+
auto dim = input_size.size();
|
| 225 |
+
std::vector<T> output_size(dim);
|
| 226 |
+
output_size[0] = input_size[input_batch_size_dim];
|
| 227 |
+
output_size[1] = weight_size[weight_output_channels_dim];
|
| 228 |
+
for (const auto d : c10::irange(2, dim)) {
|
| 229 |
+
auto dilation_ = has_dilation ? dilation[d - 2] : 1;
|
| 230 |
+
auto kernel = dilation_ * (weight_size[d] - 1) + 1;
|
| 231 |
+
output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
|
| 232 |
+
}
|
| 233 |
+
return output_size;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
inline std::vector<int64_t> conv_output_size(
|
| 237 |
+
IntArrayRef input_size, IntArrayRef weight_size,
|
| 238 |
+
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
|
| 239 |
+
) {
|
| 240 |
+
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
inline std::vector<c10::SymInt> conv_output_size(
|
| 244 |
+
SymIntArrayRef input_size, SymIntArrayRef weight_size,
|
| 245 |
+
SymIntArrayRef padding, SymIntArrayRef stride, SymIntArrayRef dilation = SymIntArrayRef()
|
| 246 |
+
) {
|
| 247 |
+
return _conv_output_size(input_size, weight_size, padding, stride, dilation);
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
template <typename T>
|
| 251 |
+
std::vector<T> _conv_input_size(
|
| 252 |
+
ArrayRef<T> output_size, ArrayRef<T> weight_size,
|
| 253 |
+
ArrayRef<T> padding, ArrayRef<T> output_padding, ArrayRef<T> stride, ArrayRef<T> dilation, T groups
|
| 254 |
+
) {
|
| 255 |
+
// ASSERT(output_size.size() > 2)
|
| 256 |
+
// ASSERT(output_size.size() == weight_size.size())
|
| 257 |
+
auto dim = output_size.size();
|
| 258 |
+
std::vector<T> input_size(dim);
|
| 259 |
+
input_size[0] = output_size[output_batch_size_dim];
|
| 260 |
+
input_size[1] = weight_size[weight_input_channels_dim] * groups;
|
| 261 |
+
for (const auto d : c10::irange(2, dim)) {
|
| 262 |
+
auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
|
| 263 |
+
input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
|
| 264 |
+
kernel + output_padding[d - 2];
|
| 265 |
+
}
|
| 266 |
+
return input_size;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
inline std::vector<c10::SymInt> conv_input_size(
|
| 270 |
+
SymIntArrayRef output_size, SymIntArrayRef weight_size,
|
| 271 |
+
SymIntArrayRef padding, SymIntArrayRef output_padding, SymIntArrayRef stride, SymIntArrayRef dilation, c10::SymInt groups
|
| 272 |
+
) {
|
| 273 |
+
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, std::move(groups));
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
inline std::vector<int64_t> conv_input_size(
|
| 277 |
+
IntArrayRef output_size, IntArrayRef weight_size,
|
| 278 |
+
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 279 |
+
) {
|
| 280 |
+
return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
template <typename T>
|
| 284 |
+
std::vector<T> _conv_weight_size(
|
| 285 |
+
ArrayRef<T> input_size, ArrayRef<T> output_size,
|
| 286 |
+
ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 287 |
+
) {
|
| 288 |
+
auto dim = input_size.size();
|
| 289 |
+
std::vector<T> weight_size(dim);
|
| 290 |
+
weight_size[0] = output_size[1];
|
| 291 |
+
weight_size[1] = input_size[1] / groups;
|
| 292 |
+
for (const auto d : c10::irange(2, dim)) {
|
| 293 |
+
auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
|
| 294 |
+
+ padding[d - 2] * 2 - output_padding[d - 2];
|
| 295 |
+
weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
|
| 296 |
+
}
|
| 297 |
+
return weight_size;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
inline std::vector<c10::SymInt> conv_weight_size(
|
| 301 |
+
SymIntArrayRef input_size, SymIntArrayRef output_size,
|
| 302 |
+
SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 303 |
+
) {
|
| 304 |
+
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
inline std::vector<int64_t> conv_weight_size(
|
| 308 |
+
IntArrayRef input_size, IntArrayRef output_size,
|
| 309 |
+
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
|
| 310 |
+
) {
|
| 311 |
+
return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
|
| 315 |
+
std::vector<int64_t> shape(dim, 1);
|
| 316 |
+
shape[1] = -1;
|
| 317 |
+
return bias.reshape(shape);
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
|
| 321 |
+
// disable NHWC for float64 input.
|
| 322 |
+
if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
|
| 323 |
+
input.scalar_type() == at::kDouble ||
|
| 324 |
+
weight.scalar_type() == at::kDouble) {
|
| 325 |
+
return at::MemoryFormat::Contiguous;
|
| 326 |
+
}
|
| 327 |
+
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
|
| 328 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 329 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 330 |
+
auto weight_ndim = weight.ndimension();
|
| 331 |
+
|
| 332 |
+
bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
|
| 333 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 334 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast)
|
| 335 |
+
);
|
| 336 |
+
if (can_use_cudnn_channels_last_2d) {
|
| 337 |
+
return at::MemoryFormat::ChannelsLast;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
|
| 341 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
|
| 342 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d)
|
| 343 |
+
);
|
| 344 |
+
if (can_use_cudnn_channels_last_3d) {
|
| 345 |
+
return at::MemoryFormat::ChannelsLast3d;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
return at::MemoryFormat::Contiguous;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
// controls whether emptyCache will be called following cudnn conv benchmarking
|
| 352 |
+
TORCH_API void _cudnn_set_conv_benchmark_empty_cache(bool enable);
|
| 353 |
+
TORCH_API bool _cudnn_get_conv_benchmark_empty_cache();
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 357 |
+
|
| 358 |
+
// disable NHWC for float64 input.
|
| 359 |
+
if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
|
| 360 |
+
input.scalar_type() == at::kDouble ||
|
| 361 |
+
weight.scalar_type() == at::kDouble) {
|
| 362 |
+
return false;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
bool can_use_miopen_channels_last_2d = false;
|
| 366 |
+
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
|
| 367 |
+
// See #64427
|
| 368 |
+
static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
|
| 369 |
+
|
| 370 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 371 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 372 |
+
|
| 373 |
+
can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
|
| 374 |
+
( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 375 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast) )
|
| 376 |
+
);
|
| 377 |
+
|
| 378 |
+
bool can_use_miopen_channels_last_3d = false;
|
| 379 |
+
|
| 380 |
+
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 384 |
+
|
| 385 |
+
// disable NHWC for float64 input.
|
| 386 |
+
if (input.scalar_type() == at::kDouble ||
|
| 387 |
+
weight.scalar_type() == at::kDouble) {
|
| 388 |
+
return false;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
// disable NHWC for MkldnnCPU tensor.
|
| 392 |
+
if (input.is_mkldnn() || weight.is_mkldnn()) {
|
| 393 |
+
return false;
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 397 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 398 |
+
|
| 399 |
+
bool can_use_mkldnn_channels_last_2d =
|
| 400 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 401 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast);
|
| 402 |
+
|
| 403 |
+
bool can_use_mkldnn_channels_last_3d =
|
| 404 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
|
| 405 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d);
|
| 406 |
+
|
| 407 |
+
return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 411 |
+
|
| 412 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 413 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 414 |
+
|
| 415 |
+
bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
|
| 416 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) || (
|
| 417 |
+
weight_memory_format == at::MemoryFormat::ChannelsLast));
|
| 418 |
+
|
| 419 |
+
return can_use_thnn_channels_last_2d;
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
inline bool xpu_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
|
| 423 |
+
|
| 424 |
+
// check layout only for xpu tensor.
|
| 425 |
+
if (!input.is_xpu() || !weight.is_xpu()) {
|
| 426 |
+
return false;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
// disable NHWC for float64 input.
|
| 430 |
+
if (input.scalar_type() == at::kDouble ||
|
| 431 |
+
weight.scalar_type() == at::kDouble) {
|
| 432 |
+
return false;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
auto input_memory_format = input.suggest_memory_format();
|
| 436 |
+
auto weight_memory_format = weight.suggest_memory_format();
|
| 437 |
+
|
| 438 |
+
bool can_use_xpu_channels_last_2d =
|
| 439 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
|
| 440 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast);
|
| 441 |
+
|
| 442 |
+
bool can_use_xpu_channels_last_3d =
|
| 443 |
+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
|
| 444 |
+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d);
|
| 445 |
+
|
| 446 |
+
return can_use_xpu_channels_last_2d || can_use_xpu_channels_last_3d;
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ConvolutionMM3d.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
|
| 3 |
+
namespace at::native {
|
| 4 |
+
|
| 5 |
+
std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
|
| 6 |
+
const Tensor& grad_output,
|
| 7 |
+
const Tensor& self,
|
| 8 |
+
const Tensor& weight,
|
| 9 |
+
IntArrayRef kernel_size,
|
| 10 |
+
IntArrayRef stride,
|
| 11 |
+
IntArrayRef padding,
|
| 12 |
+
std::array<bool, 3> output_mask);
|
| 13 |
+
|
| 14 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Copy.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
class Tensor;
|
| 8 |
+
struct TensorIterator;
|
| 9 |
+
class TensorBase;
|
| 10 |
+
|
| 11 |
+
namespace native {
|
| 12 |
+
|
| 13 |
+
using copy_fn = void (*)(TensorIterator&, bool non_blocking);
|
| 14 |
+
|
| 15 |
+
DECLARE_DISPATCH(copy_fn, copy_stub);
|
| 16 |
+
|
| 17 |
+
TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src);
|
| 18 |
+
|
| 19 |
+
} // namespace native
|
| 20 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Cross.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
|
| 8 |
+
namespace native {
|
| 9 |
+
|
| 10 |
+
using cross_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const int64_t d);
|
| 11 |
+
|
| 12 |
+
DECLARE_DISPATCH(cross_fn, cross_stub);
|
| 13 |
+
|
| 14 |
+
}} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <vector>
|
| 5 |
+
|
| 6 |
+
#include <ATen/div_rtn.h>
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
|
| 10 |
+
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
|
| 11 |
+
TORCH_CHECK( \
|
| 12 |
+
T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \
|
| 13 |
+
"Need " #T " of dimension ", \
|
| 14 |
+
DIM, \
|
| 15 |
+
" and " #T ".size[", \
|
| 16 |
+
DIM_SIZE, \
|
| 17 |
+
"] == ", \
|
| 18 |
+
SIZE, \
|
| 19 |
+
" but got input to be of shape ", \
|
| 20 |
+
T.sizes())
|
| 21 |
+
|
| 22 |
+
namespace at::native::internal {
|
| 23 |
+
namespace {
|
| 24 |
+
inline bool all_positive(IntArrayRef& arr) {
|
| 25 |
+
return std::all_of(
|
| 26 |
+
arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
inline bool all_nonnegative(std::vector<int64_t>& arr) {
|
| 30 |
+
return std::all_of(
|
| 31 |
+
arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
} // namespace
|
| 35 |
+
|
| 36 |
+
// calculate the rear part of output tensor sizes
|
| 37 |
+
template <int64_t dim>
|
| 38 |
+
std::vector<int64_t> get_output_size(
|
| 39 |
+
const Tensor& input,
|
| 40 |
+
IntArrayRef kernel_size,
|
| 41 |
+
IntArrayRef stride_size,
|
| 42 |
+
IntArrayRef pad_size,
|
| 43 |
+
IntArrayRef dilation_size) {
|
| 44 |
+
std::vector<int64_t> sizes;
|
| 45 |
+
for (const auto index : c10::irange(dim)) {
|
| 46 |
+
sizes.push_back(
|
| 47 |
+
div_rtn<int64_t>(
|
| 48 |
+
input.size(index + input.dim() - dim) + 2 * pad_size[index] -
|
| 49 |
+
(dilation_size[index] * (kernel_size[index] - 1) + 1),
|
| 50 |
+
stride_size[index]) +
|
| 51 |
+
1);
|
| 52 |
+
}
|
| 53 |
+
return sizes;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
// calculate the sizes of output tensor
|
| 57 |
+
template <int64_t dim>
|
| 58 |
+
std::vector<int64_t> get_output_size(
|
| 59 |
+
const Tensor& input,
|
| 60 |
+
const Tensor& weight,
|
| 61 |
+
IntArrayRef kernel_size,
|
| 62 |
+
IntArrayRef stride_size,
|
| 63 |
+
IntArrayRef pad_size,
|
| 64 |
+
IntArrayRef dilation_size) {
|
| 65 |
+
auto output_size = get_output_size<dim>(
|
| 66 |
+
input, kernel_size, stride_size, pad_size, dilation_size);
|
| 67 |
+
output_size.insert(output_size.begin(), weight.size(0));
|
| 68 |
+
if (input.dim() == dim + 2) {
|
| 69 |
+
output_size.insert(output_size.begin(), input.size(0));
|
| 70 |
+
}
|
| 71 |
+
return output_size;
|
| 72 |
+
}
|
| 73 |
+
/*
|
| 74 |
+
slow_conv_dilated_shape_check - check user-input to dilated convolution
|
| 75 |
+
forward and backward functions.
|
| 76 |
+
*/
|
| 77 |
+
template <int64_t dim>
|
| 78 |
+
void slow_conv_dilated_shape_check(
|
| 79 |
+
const Tensor& input,
|
| 80 |
+
const Tensor& weight,
|
| 81 |
+
const Tensor& bias,
|
| 82 |
+
const Tensor& grad_output,
|
| 83 |
+
IntArrayRef kernel_size,
|
| 84 |
+
IntArrayRef stride_size,
|
| 85 |
+
IntArrayRef pad_size,
|
| 86 |
+
IntArrayRef dilation_size) {
|
| 87 |
+
/*
|
| 88 |
+
When the following tensors are defined:
|
| 89 |
+
|
| 90 |
+
bias, grad_weight, grad_output
|
| 91 |
+
|
| 92 |
+
then these are assumed to be contiguous without checking
|
| 93 |
+
because of these tensors are made contiguous by calling
|
| 94 |
+
.contiguous() method or by resizing of zero-sized tensors in
|
| 95 |
+
forward/backward functions.
|
| 96 |
+
|
| 97 |
+
When grad_weight is defined then it is assumed without
|
| 98 |
+
checking to have the same shape as weight, see backward
|
| 99 |
+
functions.
|
| 100 |
+
*/
|
| 101 |
+
// Check size arguments
|
| 102 |
+
TORCH_CHECK(
|
| 103 |
+
kernel_size.size() == dim,
|
| 104 |
+
"kernel sizes length should be ",
|
| 105 |
+
dim,
|
| 106 |
+
", but got ",
|
| 107 |
+
kernel_size.size());
|
| 108 |
+
TORCH_CHECK(
|
| 109 |
+
stride_size.size() == dim,
|
| 110 |
+
"strides length should be ",
|
| 111 |
+
dim,
|
| 112 |
+
", but got ",
|
| 113 |
+
stride_size.size());
|
| 114 |
+
TORCH_CHECK(
|
| 115 |
+
dilation_size.size() == dim,
|
| 116 |
+
"dilations length should be ",
|
| 117 |
+
dim,
|
| 118 |
+
", but got ",
|
| 119 |
+
dilation_size.size());
|
| 120 |
+
TORCH_CHECK(
|
| 121 |
+
pad_size.size() == dim,
|
| 122 |
+
"pads length should be ",
|
| 123 |
+
dim,
|
| 124 |
+
", but got ",
|
| 125 |
+
pad_size.size());
|
| 126 |
+
|
| 127 |
+
TORCH_CHECK(
|
| 128 |
+
all_positive(kernel_size),
|
| 129 |
+
"kernel size should be greater than zero, but got ",
|
| 130 |
+
kernel_size);
|
| 131 |
+
TORCH_CHECK(
|
| 132 |
+
all_positive(stride_size),
|
| 133 |
+
"stride should be greater than zero, but got ",
|
| 134 |
+
stride_size);
|
| 135 |
+
TORCH_CHECK(
|
| 136 |
+
all_positive(dilation_size),
|
| 137 |
+
"dilation should be greater than zero, but got ",
|
| 138 |
+
dilation_size);
|
| 139 |
+
|
| 140 |
+
// check input
|
| 141 |
+
TORCH_CHECK(input.defined(), "input must be defined");
|
| 142 |
+
bool is_batch = input.dim() == dim + 2;
|
| 143 |
+
int64_t n = (is_batch ? 2 : 1);
|
| 144 |
+
int64_t ndim = n + dim;
|
| 145 |
+
if (!is_batch) {
|
| 146 |
+
// input dim has to be dim + 1 if not batched
|
| 147 |
+
TORCH_CHECK(
|
| 148 |
+
input.dim() == dim + 1,
|
| 149 |
+
"input must be 4D or 5D tensor but got ",
|
| 150 |
+
input.dim(),
|
| 151 |
+
"D tensor");
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
// check output sizes
|
| 155 |
+
auto output_size = get_output_size<dim>(
|
| 156 |
+
input, kernel_size, stride_size, pad_size, dilation_size);
|
| 157 |
+
|
| 158 |
+
TORCH_CHECK(
|
| 159 |
+
all_nonnegative(output_size),
|
| 160 |
+
"calculated output size ",
|
| 161 |
+
output_size,
|
| 162 |
+
" is too small (all sizes must be non-negative)");
|
| 163 |
+
|
| 164 |
+
// check weight
|
| 165 |
+
TORCH_CHECK(weight.defined(), "weight must be defined");
|
| 166 |
+
TORCH_CHECK(
|
| 167 |
+
weight.dim() == dim + 2,
|
| 168 |
+
"weight must be ",
|
| 169 |
+
dim + 2,
|
| 170 |
+
"D tensor but got ",
|
| 171 |
+
weight.dim(),
|
| 172 |
+
"D tensor dim=",
|
| 173 |
+
dim);
|
| 174 |
+
TORCH_CHECK(
|
| 175 |
+
weight.sizes().slice(2) == kernel_size,
|
| 176 |
+
"weight[2:] shape ",
|
| 177 |
+
weight.sizes().slice(2),
|
| 178 |
+
" must be equal to kernel_size ",
|
| 179 |
+
kernel_size);
|
| 180 |
+
|
| 181 |
+
TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
|
| 182 |
+
|
| 183 |
+
// check bias when present
|
| 184 |
+
if (bias.defined()) {
|
| 185 |
+
TORCH_CHECK(
|
| 186 |
+
bias.dim() == 1,
|
| 187 |
+
"bias must be 1D tensor but got ",
|
| 188 |
+
bias.dim(),
|
| 189 |
+
"D tensor");
|
| 190 |
+
TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// check grad_output when present
|
| 194 |
+
if (grad_output.defined()) {
|
| 195 |
+
TORCH_CHECK(
|
| 196 |
+
grad_output.dim() == ndim,
|
| 197 |
+
"grad_output must be ",
|
| 198 |
+
ndim,
|
| 199 |
+
"D tensor but got ",
|
| 200 |
+
grad_output.dim(),
|
| 201 |
+
"D tensor");
|
| 202 |
+
if (is_batch) {
|
| 203 |
+
TORCH_CHECK(
|
| 204 |
+
grad_output.size(0) == input.size(0),
|
| 205 |
+
"grad_output.size(0)=",
|
| 206 |
+
grad_output.size(0),
|
| 207 |
+
" must be input.size(0)=",
|
| 208 |
+
input.size(0));
|
| 209 |
+
}
|
| 210 |
+
TORCH_CHECK(
|
| 211 |
+
grad_output.size(n - 1) == weight.size(0),
|
| 212 |
+
"grad_output.size(",
|
| 213 |
+
n - 1,
|
| 214 |
+
")=",
|
| 215 |
+
grad_output.size(n - 1),
|
| 216 |
+
" must be weight.size(0)=",
|
| 217 |
+
weight.size(0));
|
| 218 |
+
TORCH_CHECK(
|
| 219 |
+
grad_output.sizes().slice(n) == output_size,
|
| 220 |
+
"grad_output[",
|
| 221 |
+
n,
|
| 222 |
+
":] shape",
|
| 223 |
+
grad_output.sizes().slice(n),
|
| 224 |
+
" must be equal to output size ",
|
| 225 |
+
output_size);
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
} // namespace at::native::internal
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/DispatchStub.h
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <c10/util/Array.h>
|
| 6 |
+
|
| 7 |
+
#include <atomic>
|
| 8 |
+
#include <utility>
|
| 9 |
+
#include <variant>
|
| 10 |
+
|
| 11 |
+
// Implements instruction set specific function dispatch.
|
| 12 |
+
//
|
| 13 |
+
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
|
| 14 |
+
// compiled multiple times with different compiler flags (e.g. -mavx2). A
|
| 15 |
+
// DispatchStub contains a table of function pointers for a kernel. At runtime,
|
| 16 |
+
// the fastest available kernel is chosen based on the features reported by
|
| 17 |
+
// cpuinfo.
|
| 18 |
+
//
|
| 19 |
+
// Example:
|
| 20 |
+
//
|
| 21 |
+
// In native/MyKernel.h:
|
| 22 |
+
// using fn_type = void(*)(const Tensor& x);
|
| 23 |
+
// DECLARE_DISPATCH(fn_type, stub);
|
| 24 |
+
//
|
| 25 |
+
// In native/MyKernel.cpp
|
| 26 |
+
// DEFINE_DISPATCH(stub);
|
| 27 |
+
//
|
| 28 |
+
// In native/cpu/MyKernel.cpp:
|
| 29 |
+
// namespace {
|
| 30 |
+
// // use anonymous namespace so that different cpu versions won't conflict
|
| 31 |
+
// void kernel(const Tensor& x) { ... }
|
| 32 |
+
// }
|
| 33 |
+
// REGISTER_DISPATCH(stub, &kernel);
|
| 34 |
+
//
|
| 35 |
+
// To call:
|
| 36 |
+
// stub(kCPU, tensor);
|
| 37 |
+
//
|
| 38 |
+
// TODO: CPU instruction set selection should be folded into whatever
|
| 39 |
+
// the main dispatch mechanism is.
|
| 40 |
+
//
|
| 41 |
+
// Supported device types for registration:
|
| 42 |
+
// - CPU: Central Processing Unit
|
| 43 |
+
// - CUDA: NVIDIA GPUs
|
| 44 |
+
// - HIP: AMD GPUs
|
| 45 |
+
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
|
| 46 |
+
// - MTIA: Meta Training and Inference Devices
|
| 47 |
+
// - XPU: Intel GPUs
|
| 48 |
+
// - PrivateUse1: Reserved for private/custom device types
|
| 49 |
+
//
|
| 50 |
+
// If you want to update the list of supported devices, add a new dispatch_ptr
|
| 51 |
+
// member in DispatchStubImpl.h and update the get_call_ptr switch.
|
| 52 |
+
// As well you will need to update the inlined list in 'is_device_supported`
|
| 53 |
+
//
|
| 54 |
+
//
|
| 55 |
+
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
|
| 56 |
+
C10_CLANG_DIAGNOSTIC_PUSH()
|
| 57 |
+
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
|
| 58 |
+
|
| 59 |
+
namespace at::native {
|
| 60 |
+
|
| 61 |
+
enum class CPUCapability {
|
| 62 |
+
DEFAULT = 0,
|
| 63 |
+
#if defined(HAVE_VSX_CPU_DEFINITION)
|
| 64 |
+
VSX = 1,
|
| 65 |
+
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
|
| 66 |
+
ZVECTOR = 1,
|
| 67 |
+
#else
|
| 68 |
+
AVX2 = 1,
|
| 69 |
+
AVX512 = 2,
|
| 70 |
+
#endif
|
| 71 |
+
NUM_OPTIONS
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
// Enum for error types
|
| 75 |
+
enum class ErrorType {
|
| 76 |
+
MissingDeviceKernel,
|
| 77 |
+
DeviceNotSupported
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
// Alias for the return type using std::variant
|
| 81 |
+
using DispatchResult = std::variant<void*, ErrorType>;
|
| 82 |
+
|
| 83 |
+
CPUCapability get_cpu_capability();
|
| 84 |
+
|
| 85 |
+
template <typename FnPtr, typename T>
|
| 86 |
+
struct DispatchStub;
|
| 87 |
+
|
| 88 |
+
/**
|
| 89 |
+
* The sole purpose of this class is to outline methods that don't need to be
|
| 90 |
+
* specialized or otherwise inlined and duplicated (by the compiler due to
|
| 91 |
+
* template expansion), since it causes size bloat if there are a significant
|
| 92 |
+
* number of specialization of the DispatchStub<> class.
|
| 93 |
+
*/
|
| 94 |
+
struct TORCH_API DispatchStubImpl {
|
| 95 |
+
|
| 96 |
+
// The DispatchStubImpl::try_get_call_ptr() method is used to get the call
|
| 97 |
+
// pointer for a given device type. If the call pointer is not found,
|
| 98 |
+
// DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
|
| 99 |
+
// The main difference between try_get_call_ptr() and get_call_ptr() is that
|
| 100 |
+
// try_get_call_ptr() will return the ErrorType and not raise an exception.
|
| 101 |
+
DispatchResult try_get_call_ptr(
|
| 102 |
+
c10::DeviceType device_type
|
| 103 |
+
, void *DEFAULT
|
| 104 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 105 |
+
, void *AVX512
|
| 106 |
+
#endif
|
| 107 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 108 |
+
, void *AVX2
|
| 109 |
+
#endif
|
| 110 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 111 |
+
, void *VSX
|
| 112 |
+
#endif
|
| 113 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 114 |
+
, void *ZVECTOR
|
| 115 |
+
#endif
|
| 116 |
+
);
|
| 117 |
+
|
| 118 |
+
// Analogous to try_get_call_ptr(), but it will return the ErrorType and not
|
| 119 |
+
// raise an exception.
|
| 120 |
+
DispatchResult try_choose_cpu_impl(
|
| 121 |
+
void *DEFAULT
|
| 122 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 123 |
+
, void *AVX512
|
| 124 |
+
#endif
|
| 125 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 126 |
+
, void *AVX2
|
| 127 |
+
#endif
|
| 128 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 129 |
+
, void *VSX
|
| 130 |
+
#endif
|
| 131 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 132 |
+
, void *ZVECTOR
|
| 133 |
+
#endif
|
| 134 |
+
);
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
void* get_call_ptr(
|
| 138 |
+
c10::DeviceType device_type
|
| 139 |
+
, void *DEFAULT
|
| 140 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 141 |
+
, void *AVX512
|
| 142 |
+
#endif
|
| 143 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 144 |
+
, void *AVX2
|
| 145 |
+
#endif
|
| 146 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 147 |
+
, void *VSX
|
| 148 |
+
#endif
|
| 149 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 150 |
+
, void *ZVECTOR
|
| 151 |
+
#endif
|
| 152 |
+
);
|
| 153 |
+
|
| 154 |
+
/**
|
| 155 |
+
* The CPU Dispatch actual method is chosen in decreasing order of preference by
|
| 156 |
+
* DispatchStubImpl::choose_cpu_impl() in case none is found by
|
| 157 |
+
* DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
|
| 158 |
+
*/
|
| 159 |
+
void* choose_cpu_impl(
|
| 160 |
+
void *DEFAULT
|
| 161 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 162 |
+
, void *AVX512
|
| 163 |
+
#endif
|
| 164 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 165 |
+
, void *AVX2
|
| 166 |
+
#endif
|
| 167 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 168 |
+
, void *VSX
|
| 169 |
+
#endif
|
| 170 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 171 |
+
, void *ZVECTOR
|
| 172 |
+
#endif
|
| 173 |
+
);
|
| 174 |
+
|
| 175 |
+
// Fixing dispatch error in Windows debug builds.
|
| 176 |
+
// See https://github.com/pytorch/pytorch/issues/22681 for more details.
|
| 177 |
+
#if defined(_MSC_VER) && defined(_DEBUG)
|
| 178 |
+
std::atomic<void*> cpu_dispatch_ptr;
|
| 179 |
+
void* cuda_dispatch_ptr;
|
| 180 |
+
void* hip_dispatch_ptr;
|
| 181 |
+
void* mps_dispatch_ptr;
|
| 182 |
+
void* mtia_dispatch_ptr;
|
| 183 |
+
#if defined(USE_XPU)
|
| 184 |
+
void* xpu_dispatch_ptr;
|
| 185 |
+
#endif
|
| 186 |
+
void* privateuse1_dispatch_ptr;
|
| 187 |
+
#else
|
| 188 |
+
std::atomic<void*> cpu_dispatch_ptr{nullptr};
|
| 189 |
+
void* cuda_dispatch_ptr = nullptr;
|
| 190 |
+
void* hip_dispatch_ptr = nullptr;
|
| 191 |
+
void* mps_dispatch_ptr = nullptr;
|
| 192 |
+
void* mtia_dispatch_ptr = nullptr;
|
| 193 |
+
#if defined(USE_XPU)
|
| 194 |
+
void* xpu_dispatch_ptr = nullptr;
|
| 195 |
+
#endif
|
| 196 |
+
void* privateuse1_dispatch_ptr = nullptr;
|
| 197 |
+
#endif
|
| 198 |
+
};
|
| 199 |
+
|
| 200 |
+
template <typename rT, typename T, typename... Args>
|
| 201 |
+
struct DispatchStub<rT (*)(Args...), T> {
|
| 202 |
+
using FnPtr = rT (*) (Args...);
|
| 203 |
+
|
| 204 |
+
DispatchStub() = default;
|
| 205 |
+
DispatchStub(const DispatchStub&) = delete;
|
| 206 |
+
DispatchStub& operator=(const DispatchStub&) = delete;
|
| 207 |
+
|
| 208 |
+
private:
|
| 209 |
+
FnPtr get_call_ptr(const c10::DeviceType device_type) {
|
| 210 |
+
return reinterpret_cast<FnPtr>(
|
| 211 |
+
impl.get_call_ptr(device_type
|
| 212 |
+
, reinterpret_cast<void*>(DEFAULT)
|
| 213 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 214 |
+
, reinterpret_cast<void*>(AVX512)
|
| 215 |
+
#endif
|
| 216 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 217 |
+
, reinterpret_cast<void*>(AVX2)
|
| 218 |
+
#endif
|
| 219 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 220 |
+
, reinterpret_cast<void*>(VSX)
|
| 221 |
+
#endif
|
| 222 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 223 |
+
, reinterpret_cast<void*>(ZVECTOR)
|
| 224 |
+
#endif
|
| 225 |
+
)
|
| 226 |
+
);
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
public:
|
| 230 |
+
template <typename... ArgTypes>
|
| 231 |
+
rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
|
| 232 |
+
FnPtr call_ptr = get_call_ptr(device_type);
|
| 233 |
+
return (*call_ptr)(std::forward<ArgTypes>(args)...);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
|
| 237 |
+
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
#if defined(USE_XPU)
|
| 241 |
+
void set_xpu_dispatch_ptr(FnPtr fn_ptr){
|
| 242 |
+
impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 243 |
+
}
|
| 244 |
+
#endif
|
| 245 |
+
|
| 246 |
+
void set_hip_dispatch_ptr(FnPtr fn_ptr) {
|
| 247 |
+
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
void set_mps_dispatch_ptr(FnPtr fn_ptr) {
|
| 251 |
+
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
|
| 255 |
+
impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
|
| 259 |
+
impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
// Returns true if the dispatcher has a kernel registered for this device
|
| 263 |
+
// type.
|
| 264 |
+
bool is_device_supported(const c10::DeviceType device_type) {
|
| 265 |
+
auto result = impl.try_get_call_ptr(device_type
|
| 266 |
+
, reinterpret_cast<void*>(DEFAULT)
|
| 267 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 268 |
+
, reinterpret_cast<void*>(AVX512)
|
| 269 |
+
#endif
|
| 270 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 271 |
+
, reinterpret_cast<void*>(AVX2)
|
| 272 |
+
#endif
|
| 273 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 274 |
+
, reinterpret_cast<void*>(VSX)
|
| 275 |
+
#endif
|
| 276 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 277 |
+
, reinterpret_cast<void*>(ZVECTOR)
|
| 278 |
+
#endif
|
| 279 |
+
);
|
| 280 |
+
if (std::holds_alternative<ErrorType>(result)){
|
| 281 |
+
return false;
|
| 282 |
+
}
|
| 283 |
+
return true;
|
| 284 |
+
};
|
| 285 |
+
|
| 286 |
+
static TORCH_API FnPtr DEFAULT;
|
| 287 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 288 |
+
static TORCH_API FnPtr AVX512;
|
| 289 |
+
#endif
|
| 290 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 291 |
+
static TORCH_API FnPtr AVX2;
|
| 292 |
+
#endif
|
| 293 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 294 |
+
static TORCH_API FnPtr VSX;
|
| 295 |
+
#endif
|
| 296 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 297 |
+
static TORCH_API FnPtr ZVECTOR;
|
| 298 |
+
#endif
|
| 299 |
+
private:
|
| 300 |
+
DispatchStubImpl impl;
|
| 301 |
+
};
|
| 302 |
+
|
| 303 |
+
namespace {
// Registrar helper types used by the REGISTER_*_DISPATCH macros below.
// A static instance of one of these structs installs `value` into the
// matching per-device function-pointer slot of `stub` from its constructor,
// so registration happens automatically at static-initialization time.

template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterXPUDispatch {
  RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
    stub.set_xpu_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    // TODO: make this point at hip_dispatch_ptr
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterMTIADispatch {
  RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mtia_dispatch_ptr(value);
  }
};

template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_privateuse1_dispatch_ptr(value);
  }
};

} // anonymous namespace
|
| 348 |
+
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
//
// DECLARE_DISPATCH goes in a header: it declares a non-copyable stub type
// derived from DispatchStub plus an `extern` global instance `name`.
// DEFINE_DISPATCH provides that instance's single definition in a .cpp file.
#define DECLARE_DISPATCH(fn, name) \
struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
  name##_DECLARE_DISPATCH_type() = default; \
  name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
  name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
}; \
extern TORCH_API struct name##_DECLARE_DISPATCH_type name;

#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
|
| 362 |
+
|
| 363 |
+
// Defines the static per-arch member (DEFAULT/AVX512/AVX2/VSX/ZVECTOR) of a
// stub's DispatchStub base to point at `fn`.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;

// Per-ISA registration macros: each expands to REGISTER_ARCH_DISPATCH when the
// corresponding CPU capability was compiled in, and to nothing otherwise.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn) \
  REGISTER_AVX2_DISPATCH(name, fn) \
  REGISTER_VSX_DISPATCH(name, fn) \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

// Registers nullptr in every CPU slot, i.e. "this op has no CPU kernel".
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)
|
| 401 |
+
|
| 402 |
+
// Device-side registrations: each defines a static registrar object (see the
// Register*Dispatch structs above) whose constructor installs `fn` into the
// stub's per-device slot at load time.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_XPU_DISPATCH(name, fn) \
  static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MTIA_DISPATCH(name, fn) \
  static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

// REGISTER_DISPATCH resolves to the right backend based on which compiler /
// capability macros are defined at the point of use.
// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
|
| 442 |
+
} // namespace at::native
|
| 443 |
+
|
| 444 |
+
C10_CLANG_DIAGNOSTIC_POP()
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/DistributionTemplates.h
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/Dispatch_v2.h>
|
| 6 |
+
#include <ATen/Generator.h>
|
| 7 |
+
#include <ATen/ExpandUtils.h>
|
| 8 |
+
#include <ATen/Tensor.h>
|
| 9 |
+
#include <ATen/MemoryOverlap.h>
|
| 10 |
+
#include <ATen/NamedTensorUtils.h>
|
| 11 |
+
#include <ATen/native/Resize.h>
|
| 12 |
+
#include <ATen/native/TensorIterator.h>
|
| 13 |
+
#include <cmath>
|
| 14 |
+
#include <limits>
|
| 15 |
+
#include <optional>
|
| 16 |
+
|
| 17 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 18 |
+
#include <ATen/Functions.h>
|
| 19 |
+
#else
|
| 20 |
+
#include <ATen/ops/empty_like.h>
|
| 21 |
+
#include <ATen/ops/empty.h>
|
| 22 |
+
#include <ATen/ops/full.h>
|
| 23 |
+
#include <ATen/ops/view_as_real.h>
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
namespace at::native::templates {
|
| 27 |
+
|
| 28 |
+
// ==================================================== Random ========================================================
|
| 29 |
+
|
| 30 |
+
// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`.
|
| 31 |
+
// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t).
|
| 32 |
+
// This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance:
|
| 33 |
+
//
|
| 34 |
+
// auto actual = torch::empty({3, 3}, torch::half);
|
| 35 |
+
// actual.random_(0, 65504);
|
| 36 |
+
//
|
| 37 |
+
// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504
|
| 38 |
+
// and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
|
| 39 |
+
// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
|
| 40 |
+
// the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
|
| 41 |
+
// available number for torch::half dtype.
|
| 42 |
+
template<typename scalar_t>
|
| 43 |
+
int64_t update_from(int64_t from) {
|
| 44 |
+
static_assert(
|
| 45 |
+
std::is_floating_point<scalar_t>::value ||
|
| 46 |
+
std::is_same<scalar_t, at::Half>::value ||
|
| 47 |
+
std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
|
| 48 |
+
const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
|
| 49 |
+
if (from_plus_1 < from) {
|
| 50 |
+
int64_t from_ = std::abs(from + 1);
|
| 51 |
+
int n = 0;
|
| 52 |
+
while (from_ >>= 1) ++n;
|
| 53 |
+
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
|
| 54 |
+
from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
|
| 55 |
+
}
|
| 56 |
+
return from;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template<typename scalar_t>
|
| 60 |
+
int64_t update_to(int64_t to) {
|
| 61 |
+
static_assert(
|
| 62 |
+
std::is_floating_point<scalar_t>::value ||
|
| 63 |
+
std::is_same<scalar_t, at::Half>::value ||
|
| 64 |
+
std::is_same<scalar_t, at::BFloat16>::value, "scalar_t must be floating-point type");
|
| 65 |
+
const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
|
| 66 |
+
if (to_minus_1 >= to) {
|
| 67 |
+
int64_t to_ = std::abs(to - 1);
|
| 68 |
+
int n = 0;
|
| 69 |
+
while (to_ >>= 1) ++n;
|
| 70 |
+
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
|
| 71 |
+
to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
|
| 72 |
+
}
|
| 73 |
+
return to;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// Return earlier for not invoking kernel.
// See https://github.com/pytorch/pytorch/issues/103418 for more details
// NOTE: expands to an early `return tensor;`, so it may only be used inside
// functions whose return type accepts `tensor`.
#define CHECK_EMPTY_AND_RETURN(tensor) \
  if (tensor.numel() == 0) { \
    return tensor; \
  }
|
| 82 |
+
|
| 83 |
+
template<template<typename> class random_kernel, typename RNG>
|
| 84 |
+
at::Tensor& random_impl(at::Tensor& self, std::optional<Generator> generator) {
|
| 85 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 86 |
+
auto iter = at::TensorIterator::borrowing_nullary_op(self);
|
| 87 |
+
random_kernel<RNG>()(iter, generator);
|
| 88 |
+
return self;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// Hard error when `var` falls outside [min, max] for `dtype`.
#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
  TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \

// Soft warning when |var| exceeds 2^digits — the contiguous-integer range of
// a floating dtype; integers beyond it are not all exactly representable.
#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
  if (var < -(1LL << digits) || var > (1LL << digits)) { \
    TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \
      "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
      "This warning will become an error in version 1.7 release, please fix the code in advance"); \
  }
|
| 100 |
+
|
| 101 |
+
// Validates that the inclusive integer bounds [from, to_inc] fit inside the
// representable range of `dtype`, dispatching on its scalar type.
// Floating dtypes additionally warn when the bounds exceed the dtype's
// exactly-representable integer range (2^digits).
inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
  const auto scalar_type = typeMetaToScalarType(dtype);
  if (isFloatingType(scalar_type)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
      const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
      const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
      CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);

      // Warn (not yet error) when the bounds leave the dtype's
      // contiguous-integer range.
      constexpr auto digits = std::numeric_limits<scalar_t>::digits;
      WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
      WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
    });
  } else if (scalar_type == kUInt64) {
    // When you do a comparison between int64_t and uint64_t, the usual
    // arithmetic conversions say that the int64_t value is promoted to
    // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
    // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
    // the right thing to do.
    CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
    CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
  } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
    AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
      const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
      const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
      CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
    }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
  } else {
    TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
  }
}
|
| 133 |
+
|
| 134 |
+
// In-place discrete-uniform fill over a caller-specified range.
// Three cases, mirroring the Tensor::random_ overloads:
//   1. to_opt has a value:                  sample from [from, to)
//   2. to_opt empty, from != int64 lowest:  sample from [from, dtype/int64 max]
//   3. to_opt empty, from == int64 lowest:  sample the full 64-bit range (2^64)
template<template<typename> class random_from_to_kernel, typename RNG>
at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional<int64_t> to_opt, std::optional<Generator> generator) {
  uint64_t range = 0;
  auto iter = at::TensorIterator::borrowing_nullary_op(self);
  if (to_opt.has_value()) {
    // [from, to)
    int64_t to = *to_opt;
    TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
    if (isFloatingType(iter.dtype())) {
      // Tighten the bounds so that casting the sampled value to a
      // low-precision dtype cannot round it outside [from, to).
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
        from = update_from<scalar_t>(from);
        to = update_to<scalar_t>(to);
        TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
      });
    }
    check_from_to_in_range(from, to - 1, self.dtype());
    CHECK_EMPTY_AND_RETURN(self);
    // Unsigned subtraction is well-defined even for negative `from`.
    range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
    random_from_to_kernel<RNG>()(iter, range, from, generator);
  } else if (from != std::numeric_limits<int64_t>::lowest()) {
    // [from, std::numeric_limits<int64_t>::max()]
    int64_t to_inc = 0;
    if (isFloatingType(iter.dtype())) {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
        // Cap the inclusive upper bound at 2^digits, the largest contiguous
        // integer exactly representable in the floating dtype.
        constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
        to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
        from = update_from<scalar_t>(from);
        TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
      });
    } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
      AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
        if constexpr (std::is_same_v<scalar_t, bool>) {
          to_inc = static_cast<int64_t>(true);
        } else {
          to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
        }
      }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
    } else {
      TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
    }
    check_from_to_in_range(from, to_inc, self.dtype());
    CHECK_EMPTY_AND_RETURN(self);
    range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
    random_from_to_kernel<RNG>()(iter, range, from, generator);
  } else {
    // [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
    // range = 2^64
    CHECK_EMPTY_AND_RETURN(self);
    random_from_to_kernel<RNG>()(iter, generator);
  }
  return self;
}
|
| 186 |
+
|
| 187 |
+
// ==================================================== Normal ========================================================
|
| 188 |
+
|
| 189 |
+
// Validates a tensor-valued std: must be real-valued and elementwise >= 0.
// Empty and meta tensors skip the min-reduction check.
#define CHECK_NORMAL_TENSOR_STD(std) \
  do { \
    TORCH_CHECK( \
      !std.is_complex(), \
      "normal expects standard deviation to be non-complex"); \
    TORCH_CHECK( \
      std.numel() == 0 || std.is_meta() || std.min().ge(0).item<bool>(), \
      "normal expects all elements of std >= 0.0"); \
  } while (0)

// Validates a scalar std: must be non-negative.
#define CHECK_NORMAL_STD(std) \
  TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std);
|
| 201 |
+
|
| 202 |
+
// In-place normal fill: self ~ N(mean, std^2).
// For complex tensors the real and imaginary parts are sampled independently
// with std / sqrt(2) each, so the complex sample keeps the requested total
// variance.
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  CHECK_NORMAL_STD(std);
  CHECK_EMPTY_AND_RETURN(self);

  if (self.is_complex()) {
    auto float_tensor = at::view_as_real(self);
    // variance for normal distribution of the real and imaginary values
    // is half of the input variance
    normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
  } else {
    normal_kernel<RNG>()(self, mean, std, gen);
  }
  return self;
}
|
| 217 |
+
|
| 218 |
+
// Out-of-place normal with tensor mean and scalar std:
// output ~ N(mean, std^2), resized to broadcast(mean, output).
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional<Generator> gen) {
  CHECK_NORMAL_STD(std);
  // std_tensor exists only to contribute output's current sizes to the
  // broadcast shape computation below.
  auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
  auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
  at::native::resize_output(output, shape);
  // Sample zero-mean noise first, then shift by the mean tensor.
  normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
  output.add_(mean);
  return output;
}
|
| 228 |
+
|
| 229 |
+
// Out-of-place normal with scalar mean and tensor std:
// output ~ N(mean, std^2) elementwise.
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional<Generator> gen) {
  CHECK_NORMAL_TENSOR_STD(std);
  // Scalar mean lifted to a 0-dim tensor for shape inference and the final add.
  auto mean_tensor = at::full({}, mean, output.options());
  auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
  at::native::resize_output(output, shape);
  // Sample standard-normal noise, then scale and shift: out = out * std + mean.
  normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
  // CUDA NB: addcmul_out copies the tensor to be added into the output.
  // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
  // The third argument is not a constant reference and hence the samples in output are overwritten.
  // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
  output.mul_(std).add_(mean_tensor);
  return output;
}
|
| 243 |
+
|
| 244 |
+
// Out-of-place normal with tensor mean and tensor std:
// output ~ N(mean, std^2) elementwise over the broadcast shape.
template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  CHECK_NORMAL_TENSOR_STD(std);
  auto shape = at::infer_size(mean.sizes(), std.sizes());
  at::native::resize_output(output, shape);
  // Sample standard-normal noise, then scale and shift: out = out * std + mean.
  normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
  // CUDA NB: addcmul_out copies the tensor to be added into the output.
  // The previous function here was addcmul_out(output, mean, output, std, 1);
  // The third argument is not a constant reference and hence the samples in output are overwritten.
  // Consequently, the computation performed is mean + mean * std instead of mean + output * std
  output.mul_(std).add_(mean);
  return output;
}
|
| 257 |
+
|
| 258 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 259 |
+
Tensor normal_impl(const Tensor& mean, double std, std::optional<Generator> gen) {
|
| 260 |
+
CHECK_NORMAL_STD(std);
|
| 261 |
+
Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
|
| 262 |
+
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
|
| 263 |
+
return ret;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 267 |
+
Tensor normal_impl(double mean, const Tensor& std, std::optional<Generator> gen) {
|
| 268 |
+
CHECK_NORMAL_TENSOR_STD(std);
|
| 269 |
+
Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
|
| 270 |
+
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
|
| 271 |
+
return ret;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
template<template<typename> class normal_kernel, typename RNG>
|
| 275 |
+
Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
|
| 276 |
+
CHECK_NORMAL_TENSOR_STD(std);
|
| 277 |
+
auto shape = at::infer_size(mean.sizes(), std.sizes());
|
| 278 |
+
Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
|
| 279 |
+
normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
|
| 280 |
+
return ret;
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// ==================================================== Uniform =======================================================
|
| 284 |
+
|
| 285 |
+
// In-place continuous-uniform fill over [from, to).
// Complex tensors are filled by recursing on the real view, i.e. real and
// imaginary parts receive independent uniform draws.
template<template<typename> class uniform_kernel, typename RNG>
at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional<Generator> generator) {
  if (self.is_complex()) {
    CHECK_EMPTY_AND_RETURN(self);
    auto float_tensor = at::view_as_real(self);
    uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
      [[maybe_unused]] const auto dtype = self.dtype();
      const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
      const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
      CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
      CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
      TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
      TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
            "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
            ">::max(), but found to=", to, " and from=", from,
            " which result in to-from to exceed the limit");
      // Clamp the validated bounds into the dtype's representable range
      // before handing them to the kernel.
      from = std::min(std::max(from, min), max);
      to = std::max(std::min(to, max), min);
    });
    CHECK_EMPTY_AND_RETURN(self);
    auto iter = at::TensorIterator::borrowing_nullary_op(self);
    uniform_kernel<RNG>()(iter, from, to, generator);
  }
  return self;
}
|
| 312 |
+
|
| 313 |
+
// ================================================== LogNormal =======================================================
|
| 314 |
+
|
| 315 |
+
template<template<typename> class log_normal_kernel, typename RNG>
|
| 316 |
+
at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional<Generator> gen) {
|
| 317 |
+
TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
|
| 318 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 319 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 320 |
+
log_normal_kernel<RNG>()(iter, mean, std, gen);
|
| 321 |
+
return self;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
// =================================================== Geometric ======================================================
|
| 325 |
+
|
| 326 |
+
template<template<typename> class geometric_kernel, typename RNG>
|
| 327 |
+
Tensor& geometric_impl_(Tensor& self, double p, std::optional<Generator> gen) {
|
| 328 |
+
TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
|
| 329 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 330 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 331 |
+
geometric_kernel<RNG>()(iter, p, gen);
|
| 332 |
+
return self;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
// ================================================== Exponential =====================================================
|
| 336 |
+
|
| 337 |
+
template<template<typename> class exponential_kernel, typename RNG>
|
| 338 |
+
Tensor& exponential_impl_(Tensor& self, double lambda, std::optional<Generator> gen) {
|
| 339 |
+
TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
|
| 340 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 341 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 342 |
+
exponential_kernel<RNG>()(iter, lambda, gen);
|
| 343 |
+
return self;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
// ==================================================== Cauchy ========================================================
|
| 347 |
+
|
| 348 |
+
template<template<typename> class cauchy_kernel, typename RNG>
|
| 349 |
+
Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
|
| 350 |
+
// TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
|
| 351 |
+
// the variance, squared sigma, is undefined for cauchy distribution
|
| 352 |
+
TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
|
| 353 |
+
TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
|
| 354 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 355 |
+
auto iter = TensorIterator::borrowing_nullary_op(self);
|
| 356 |
+
cauchy_kernel<RNG>()(iter, median, sigma, gen);
|
| 357 |
+
return self;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// ==================================================== Bernoulli =====================================================
|
| 361 |
+
|
| 362 |
+
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
|
| 363 |
+
Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
|
| 364 |
+
CHECK_EMPTY_AND_RETURN(self);
|
| 365 |
+
NoNamesGuard guard;
|
| 366 |
+
at::assert_no_internal_overlap(self);
|
| 367 |
+
bernoulli_tensor_kernel<RNG>()(self, p_, gen);
|
| 368 |
+
return self;
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
// In-place Bernoulli draw with a single scalar success probability `p`.
template<template<typename> class bernoulli_scalar_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, double p, std::optional<Generator> gen) {
  TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
  CHECK_EMPTY_AND_RETURN(self);
  // The kernel writes every element; overlapping internal memory would make
  // the result nondeterministic.
  at::assert_no_internal_overlap(self);
  bernoulli_scalar_kernel<RNG>()(self, p, gen);
  return self;
}
|
| 379 |
+
|
| 380 |
+
// Out-of-place Bernoulli: resizes `result` to self's shape and fills it with
// draws whose per-element probabilities are the values of `self`.
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional<Generator> gen) {
  // result.resize_as_(self) requires self to have same dtype as result, so we
  // use resize_ instead.
  // TODO: Fix resize_as_. See pytorch/pytorch#11665.
  result.resize_(self.sizes());
  bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
  // Re-attach named-tensor names dropped inside bernoulli_impl_'s NoNamesGuard.
  namedinference::propagate_names(result, self);
  return result;
}
|
| 390 |
+
|
| 391 |
+
#undef CHECK_OUT_OF_BOUNDS
|
| 392 |
+
#undef WARN_OUT_OF_BOUNDS
|
| 393 |
+
|
| 394 |
+
} // namespace at::native::templates
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distributions.h
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/Math.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <c10/util/MathConstants.h>
|
| 6 |
+
|
| 7 |
+
// ROCM hcc doesn't work well with using std:: in kernel functions
|
| 8 |
+
#if defined(__CUDA_ARCH__)
|
| 9 |
+
#include <c10/cuda/CUDAMathCompat.h>
|
| 10 |
+
#define compat_exp c10::cuda::compat::exp
|
| 11 |
+
#define compat_ceil c10::cuda::compat::ceil
|
| 12 |
+
#define compat_floor c10::cuda::compat::floor
|
| 13 |
+
#define compat_log c10::cuda::compat::log
|
| 14 |
+
#define compat_pow c10::cuda::compat::pow
|
| 15 |
+
#define compat_sqrt c10::cuda::compat::sqrt
|
| 16 |
+
#define compat_tan c10::cuda::compat::tan
|
| 17 |
+
#define compat_abs c10::cuda::compat::abs
|
| 18 |
+
#define compat_log1p c10::cuda::compat::log1p
|
| 19 |
+
#elif defined(__HIPCC__)
|
| 20 |
+
#include <c10/hip/HIPMathCompat.h>
|
| 21 |
+
#define compat_exp c10::hip::compat::exp
|
| 22 |
+
#define compat_ceil c10::hip::compat::ceil
|
| 23 |
+
#define compat_floor c10::hip::compat::floor
|
| 24 |
+
#define compat_log c10::hip::compat::log
|
| 25 |
+
#define compat_pow c10::hip::compat::pow
|
| 26 |
+
#define compat_sqrt c10::hip::compat::sqrt
|
| 27 |
+
#define compat_tan c10::hip::compat::tan
|
| 28 |
+
#define compat_abs c10::hip::compat::abs
|
| 29 |
+
#define compat_log1p c10::hip::compat::log1p
|
| 30 |
+
#else
|
| 31 |
+
#define compat_exp std::exp
|
| 32 |
+
#define compat_ceil std::ceil
|
| 33 |
+
#define compat_floor std::floor
|
| 34 |
+
#define compat_log std::log
|
| 35 |
+
#define compat_pow std::pow
|
| 36 |
+
#define compat_sqrt std::sqrt
|
| 37 |
+
#define compat_tan std::tan
|
| 38 |
+
#define compat_abs std::abs
|
| 39 |
+
#define compat_log1p std::log1p
|
| 40 |
+
#endif
|
| 41 |
+
|
| 42 |
+
namespace {
|
| 43 |
+
|
| 44 |
+
#if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
|
| 45 |
+
// we cannot use std::isnan directly due to some incompatibility of
|
| 46 |
+
// gcc constexpr'ing and nvcc
|
| 47 |
+
using std::isnan;
|
| 48 |
+
#endif
|
| 49 |
+
|
| 50 |
+
// Here sampler_t should be function type scalar_t(void). For gpu
// "sampler" is a device function, but since ROCM doesn't have
// equivalent to nvstd::function, we use a template type parameter to
// capture it.
//
// Thin wrapper so sampling code below can take "a source of random
// scalar_t values" uniformly on host and device.
template<typename scalar_t, typename sampler_t>
struct BaseSampler {
  // The wrapped callable; stored by value (copied from the constructor arg).
  sampler_t sampler;
  C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
  // Draw one value by invoking the wrapped callable.
  C10_DEVICE scalar_t sample() {
    return sampler();
  }
};
|
| 62 |
+
|
| 63 |
+
// The function `sample_gamma` is
|
| 64 |
+
// is adapted from Numpy's distributions.c implementation.
|
| 65 |
+
// It is MIT licensed, so here is the copyright:
|
| 66 |
+
|
| 67 |
+
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
|
| 68 |
+
*
|
| 69 |
+
* Permission is hereby granted, free of charge, to any person obtaining a
|
| 70 |
+
* copy of this software and associated documentation files (the
|
| 71 |
+
* "Software"), to deal in the Software without restriction, including
|
| 72 |
+
* without limitation the rights to use, copy, modify, merge, publish,
|
| 73 |
+
* distribute, sublicense, and/or sell copies of the Software, and to
|
| 74 |
+
* permit persons to whom the Software is furnished to do so, subject to
|
| 75 |
+
* the following conditions:
|
| 76 |
+
*
|
| 77 |
+
* The above copyright notice and this permission notice shall be included
|
| 78 |
+
* in all copies or substantial portions of the Software.
|
| 79 |
+
*
|
| 80 |
+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
| 81 |
+
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 82 |
+
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 83 |
+
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
| 84 |
+
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
| 85 |
+
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
| 86 |
+
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 87 |
+
*/
|
| 88 |
+
|
| 89 |
+
// Draws one sample from Gamma(alpha, 1) given sources of standard-uniform and
// standard-normal variates. Uses the Marsaglia–Tsang squeeze/rejection method
// (see citation below); for alpha < 1 the shape is boosted by 1 and the result
// rescaled, per the standard alpha-boosting trick.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
  accscalar_t scale = 1.0f;

  // Boost alpha for higher acceptance probability.
  if (alpha < 1.0f) {
    // Gamma(0) is degenerate at 0.
    if (alpha == 0.f) return 0.f;
    // scale = U^(1/alpha) compensates for sampling with alpha + 1 below.
    scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
    alpha += 1.0f;
  }

  // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
  // doi:10.1145/358407.358414
  const accscalar_t d = alpha - 1.0f / 3.0f;
  const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
  for (;;) {
    accscalar_t x, y;
    do {
      x = standard_normal.sample();
      y = 1.0f + c * x;
    } while (y <= 0);  // resample until the cubed term is positive
    const accscalar_t v = y * y * y;
    // 1 - U keeps u in (0, 1] so compat_log(u) below is finite.
    const accscalar_t u = 1 - standard_uniform.sample();
    const accscalar_t xx = x * x;
    // Cheap "squeeze" acceptance test (avoids the log in most iterations).
    if (u < 1.0f - 0.0331f * xx * xx)
      return static_cast<scalar_t>(scale * d * v);
    // Full acceptance test in log space.
    if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
      return static_cast<scalar_t>(scale * d * v);
  }
}
|
| 119 |
+
|
| 120 |
+
/* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
|
| 121 |
+
* from TensorFlow's random_binomial_op.cc implementation. That code is under
|
| 122 |
+
* copyright: 2019 The TensorFlow Authors.
|
| 123 |
+
*
|
| 124 |
+
* It was released under the Apache License, Version 2.0 (the "License"), available at:
|
| 125 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 126 |
+
*/
|
| 127 |
+
|
| 128 |
+
// Tail correction term of the Stirling approximation to log(k!), used by the
// btrs() acceptance bound below. For k in [0, 9] a precomputed table is used;
// for larger k an asymptotic series in 1/(k+1) is evaluated.
// NOTE(review): assumes k >= 0 — negative k would index out of bounds; callers
// in this file only pass non-negative counts.
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
  // Exact tail values for k = 0..9.
  const static scalar_t kTailValues[] = {
    0.0810614667953272,
    0.0413406959554092,
    0.0276779256849983,
    0.02079067210376509,
    0.0166446911898211,
    0.0138761288230707,
    0.0118967099458917,
    0.0104112652619720,
    0.00925546218271273,
    0.00833056343336287
  };
  if (k <= 9) {
    return kTailValues[static_cast<size_t>(k)];
  }
  // Asymptotic expansion for k > 9.
  scalar_t kp1sq = (k + 1) * (k + 1);
  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
// Samples Binomial(count, prob) by inversion: counts how many geometric
// inter-arrival gaps fit inside `count` trials. Efficient only when
// count * prob is small (the caller, sample_binomial, enforces this).
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  accscalar_t U;
  accscalar_t geom_sum = 0;
  scalar_t num_geom = 0;

  // log(1 - prob); log1p keeps precision for small prob.
  accscalar_t logprob = compat_log1p(-prob);

  while (1) {
    U = standard_uniform.sample();
    // Geometric variate: number of trials until the next success.
    accscalar_t geom = compat_ceil(compat_log(U) / logprob);
    geom_sum += geom;
    // Stop once the accumulated trials exceed the budget of `count`.
    if (geom_sum > count) {
      break;
    }
    num_geom = num_geom + 1;
  }
  // Number of successes observed within `count` trials.
  return num_geom;
}
|
| 169 |
+
|
| 170 |
+
// Samples Binomial(count, prob) with the BTRS (binomial transformed rejection
// with squeeze) algorithm of Hormann (see adaptation note at the top of this
// section). Intended for the regime count * prob >= 10; sample_binomial
// routes small-mean cases to binomial_inversion instead.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  scalar_t k;
  accscalar_t U, V, us;

  // This is spq in the paper.
  const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));

  // Other coefficients for Transformed Rejection sampling.
  const accscalar_t b = 1.15 + 2.53 * stddev;
  const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const accscalar_t c = count * prob + 0.5;
  const accscalar_t v_r = 0.92 - 4.2 / b;
  const accscalar_t r = prob / (1 - prob);

  const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
  // m is the mode of the distribution, used in the acceptance bound below.
  const accscalar_t m = compat_floor((count + 1) * prob);

  while (1) {
    U = standard_uniform.sample() - 0.5;
    V = standard_uniform.sample();

    us = 0.5 - compat_abs(U);
    // Candidate sample from the transformed (u, v) pair.
    k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));

    // Reject non-sensical answers.
    if (k < 0 || k > count) {
      continue;
    }
    // Region for which the box is tight, and we can return our calculated value.
    // This should happen 0.86 * v_r times. In the limit as n * p is large,
    // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
    if (us >= 0.07 && V <= v_r) {
      return k;
    }

    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
    // For all (u, v) pairs outside of the bounding box, this calculates the
    // transformed-reject ratio.
    V = compat_log(V * alpha / (a / (us * us) + b));
    // Log-probability bound built from Stirling-tail corrections.
    accscalar_t upperbound =
        ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
         (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
         (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
         stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
         stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));

    if (V <= upperbound) {
      return k;
    }
  }
}
|
| 222 |
+
|
| 223 |
+
// Entry point for binomial sampling: handles degenerate cases (count or prob
// non-positive -> 0; prob >= 1 -> count), then picks between btrs (large
// mean) and binomial_inversion (small mean). prob > 0.5 is handled by
// sampling the complement with q = 1 - prob, which keeps both helpers in
// their efficient/valid regime. NaN prob falls through every comparison and
// yields NaN.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  if (count <= 0.0 || prob <= 0.0) {
    return 0;
  } else if (prob >= 1.0) {
    return count;
  } else if (prob <= 0.5) {
    if (count * prob >= 10.0) {
      // btrs
      return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
    } else {
      // binomial inversion
      return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
    }
  } else if (prob > 0.5) {
    // Sample the complement: Binomial(n, p) = n - Binomial(n, 1 - p).
    scalar_t qprob = 1.0 - prob;
    if (count * qprob >= 10.0) {
      // btrs
      return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
    } else {
      // count - binomial inversion
      return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
    }
  } else {
    // prob is nan?
    return static_cast<scalar_t>(NAN);
  }
}
|
| 251 |
+
|
| 252 |
+
/*
|
| 253 |
+
* This function is derived from the implementation of the digamma function in the Cephes Math Library.
|
| 254 |
+
* See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
|
| 255 |
+
*/
|
| 256 |
+
// Digamma (psi) function for a single value, usable in device code.
// Poles at 0 and negative integers return INFINITY. Negative non-integer x
// is mapped to positive x via the reflection formula; x is then shifted up
// to >= 10 by the recurrence psi(x+1) = psi(x) + 1/x before evaluating the
// asymptotic series.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
  // psi(10), used to terminate the recurrence exactly at x == 10.
  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
  if (x == 0) {
    return INFINITY;
  }
  accscalar_t additional_summand = 0;
  int x_is_integer = x == compat_floor(x);
  if (x < 0) {
    // Poles at negative integers.
    if (x_is_integer) {
      return INFINITY;
    }
    // it is more standard to write this as recursion, but
    // nvcc does not like that
    // Reflection formula term: psi(1-x) - psi(x) = pi * cot(pi * x).
    additional_summand = -c10::pi<scalar_t> /
        compat_tan(c10::pi<scalar_t> * x);
    x = 1 - x;
  }

  // Push x to be >= 10
  accscalar_t result = 0;
  while (x < 10) {
    result -= 1 / x;
    x += 1;
  }
  if (x == 10) {
    return result + PSI_10 + additional_summand;
  }

  // Compute asymptotic digamma
  // Series coefficients from the Cephes Math Library (see note above).
  static const accscalar_t A[] = {
    8.33333333333333333333E-2,
    -2.10927960927960927961E-2,
    7.57575757575757575758E-3,
    -4.16666666666666666667E-3,
    3.96825396825396825397E-3,
    -8.33333333333333333333E-3,
    8.33333333333333333333E-2,
  };

  accscalar_t y = 0;
  // For huge x the series term underflows; skip it.
  if (x < 1.0e17f) {
    accscalar_t z = 1.0 / (x * x);
    y = z * polevl<accscalar_t>(z, A, 6);
  }
  return static_cast<scalar_t>(
      result + compat_log(x) - (0.5f / x) - y + additional_summand);
}
|
| 304 |
+
|
| 305 |
+
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
|
| 306 |
+
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
|
| 307 |
+
// Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
// for random number x drawn from a standard Gamma distribution Gamma(alpha).
// Three regimes are used: a Taylor series for small x, a Rice saddle point
// expansion for large alpha, and a fitted bivariate rational approximation
// elsewhere.
template <typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
  // Use a Taylor series expansion for small x.
  accscalar_t x = static_cast<accscalar_t>(x_);
  accscalar_t alpha = static_cast<accscalar_t>(alpha_);
  if (x < 0.8f) {
    accscalar_t numer = 1;
    accscalar_t denom = alpha;
    // series1 approximates the CDF factor, series2 its alpha-derivative part.
    auto series1 = numer / denom;
    auto series2 = numer / (denom * denom);
    for (int i = 1; i <= 5; ++i) {
      numer *= -x / static_cast<accscalar_t>(i);
      denom += 1;
      series1 += numer / denom;
      series2 += numer / (denom * denom);
    }
    const auto pow_x_alpha = compat_pow(x, alpha);
    const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
    const auto gamma_cdf = pow_x_alpha * series1;
    const auto gamma_cdf_alpha =
        (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
            gamma_cdf -
        pow_x_alpha * series2;
    const auto result = -gamma_cdf_alpha / gamma_pdf;
    // Guard against 0/0 at degenerate inputs.
    return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
  }

  // Use a Rice saddle point expansion for large alpha.
  if (alpha > 8.0f) {
    // Separate rational approximation near x == alpha to avoid the
    // singularity of the saddle-point formula there.
    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
      const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
      const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
      const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
      return static_cast<scalar_t>(numer_1 * numer_2 / denom);
    }
    const auto denom = compat_sqrt(8 * alpha);
    const auto term2 = denom / (alpha - x);
    const auto term3 = compat_pow(
        x - alpha - alpha * compat_log(x / alpha),
        static_cast<accscalar_t>(-1.5));
    const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
    const auto term1 = compat_log(x / alpha) * term23 -
        compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
    const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
    const auto numer = x * term1;
    return static_cast<scalar_t>(-stirling * numer / denom);
  }

  // Use a bivariate rational approximation to the reparameterized gradient.
  const auto u = compat_log(x / alpha);
  const auto v = compat_log(alpha);
  // Fitted coefficients: rows are powers of u; first 4 columns form the
  // numerator polynomial in v, last 4 the denominator.
  static const accscalar_t coef_uv[3][8] = {
    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
     1, 0.32668115, 0.10406089, 0.0014179084},
    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  accscalar_t coef_v[8];
  for (int i = 0; i < 8; ++ i) {
    coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  }
  const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  return static_cast<scalar_t>(compat_exp(p / q));
}
|
| 375 |
+
|
| 376 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
|
| 377 |
+
// Assumes x is close to zero and uses a Taylor expansion.
|
| 378 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
      - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
  scalar_t numer = 1;
  // Term-by-term Taylor series in x; 10 terms.
  scalar_t series = numer / alpha * (factor + 1 / alpha);
  for (int i = 1; i <= 10; ++i) {
    scalar_t casted_i = static_cast<scalar_t>(i);
    numer *= (casted_i - beta) * x / casted_i;
    const scalar_t denom = alpha + casted_i;
    series += numer / denom * (factor + 1 / denom);
  }
  const scalar_t result = x * compat_pow(1 - x, -beta) * series;
  // Guard against NaN from degenerate inputs (e.g. x == 0).
  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
|
| 393 |
+
|
| 394 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
|
| 395 |
+
// Assumes x is close to zero and uses a Taylor expansion.
|
| 396 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
// Assumes x is close to zero and uses a Taylor expansion.
template <typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
  // betas tracks the falling-factorial product in beta; dbetas its beta-derivative.
  scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
  for (int i = 1; i <= 8; ++i) {
    scalar_t casted_i = static_cast<scalar_t>(i);
    numer *= -x / casted_i;
    dbetas = dbetas * (beta - casted_i) + betas;
    betas = betas * (beta - casted_i);
    series += numer / (alpha + casted_i) * (dbetas + factor * betas);
  }
  const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
  // Guard against NaN from degenerate inputs.
  return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
}
|
| 410 |
+
|
| 411 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
|
| 412 |
+
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
|
| 413 |
+
// To ensure numerical stability, this computation is performed at higher precision.
|
| 414 |
+
// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
// Assumes alpha and beta are both large and uses a Rice saddle point expansion.
// To ensure numerical stability, this computation is performed at higher precision.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
  const accscalar_t total = alpha + beta;
  const accscalar_t mean = alpha / total;
  const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
  if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
    // Avoid the singularity at x = mean.
    // Polynomial expansion around the mean (nested Horner form in alpha).
    const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
        (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
        3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
        (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
        8 * (1 - x) * (135 * beta - 11)))));
    const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
    const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
    return prefactor_num / (1 - x) * poly / prefactor_den;
  }
  // Saddle-point expansion away from the mean.
  const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
  const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
      * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
      / (1 + 1 / (12 * total) + 1 / (288 * total * total));
  const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
  const accscalar_t axbx = alpha * (x - 1) + beta * x;
  const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
  const accscalar_t term1 = term1_num / term1_den;
  const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
  const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
  const accscalar_t term3_den = beta * x + alpha * (x - 1);
  const accscalar_t term3 = term3_num / term3_den;
  const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
      alpha * compat_log(alpha / (total * x));
  const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
  // term4's sign flips across the mean.
  const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
  return static_cast<scalar_t>(stirling * prefactor * term1234);
}
|
| 448 |
+
|
| 449 |
+
// Computes a scaled reparameterized gradient
|
| 450 |
+
// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
|
| 451 |
+
// for random number x drawn from a Beta distribution Beta(alpha,beta).
|
| 452 |
+
// This function inputs total=alpha+beta to make it easy to implement
|
| 453 |
+
// Dirichlet reparameterized gradients in terms of Betas.
|
| 454 |
+
// Computes a scaled reparameterized gradient
// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
// for random number x drawn from a Beta distribution Beta(alpha,beta).
// This function inputs total=alpha+beta to make it easy to implement
// Dirichlet reparameterized gradients in terms of Betas.
// Dispatches between three asymptotic approximations and, as a fallback,
// applies a fitted rational correction to an analytic approximation.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
  // Higher-precision copies used by the mid-range and fallback branches.
  accscalar_t x_ = static_cast<accscalar_t>(x);
  accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
  accscalar_t total_ = static_cast<accscalar_t>(total);

  const scalar_t beta = total - alpha;
  const accscalar_t beta_ = total_ - alpha_;
  const scalar_t boundary = total * x * (1 - x);

  // Use an asymptotic approximation for x close to 0.
  if (x <= 0.5f && boundary < 2.5f) {
    return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
  }

  // Use an asymptotic approximation for x close to 1.
  if (x >= 0.5f && boundary < 0.75f) {
    return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
  }

  // Use an asymptotic approximation when alpha and (total - alpha) are both large.
  if (alpha > 6 && beta > 6) {
    return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
  }

  // Use a rational correction to an analytic approximation.
  // Fitted coefficients: c[0] is the numerator table, c[1] the denominator;
  // inner indices are powers of u, a, and b respectively.
  static const accscalar_t c[2][3][3][4] = {
    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
  };
  const accscalar_t u = compat_log(x_);
  const accscalar_t a = compat_log(alpha_) - u;
  const accscalar_t b = compat_log(total_) - a;
  const accscalar_t pow_u[3] = {1, u, u * u};
  const accscalar_t pow_a[3] = {1, a, a * a};
  accscalar_t p = 0.0;
  accscalar_t q = 0.0;
  // Evaluate numerator p and denominator q of the rational correction.
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      const accscalar_t ua = pow_u[i] * pow_a[j];
      p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
      q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
    }
  }
  // Analytic base approximation, scaled by the fitted correction p/q.
  const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
  return static_cast<scalar_t>(p / q * approx);
}
|
| 517 |
+
|
| 518 |
+
} // namespace
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
#ifdef USE_FBGEMM
|
| 6 |
+
#include <fbgemm/FbgemmEmbedding.h>
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
// Reduction modes accepted by embedding_bag; the numeric values match the
// raw int64_t `mode` argument used across the ATen embedding_bag API.
enum class EmbeddingBagMode {
  SUM = 0,
  MEAN = 1,
  MAX = 2,
};

// Convenience comparisons so the raw int64_t `mode` argument can be checked
// directly against the enum (e.g. `mode == EmbeddingBagMode::SUM`).
[[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
  return static_cast<int64_t>(op2) == op1;
}

[[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
  return static_cast<int64_t>(op2) != op1;
}
|
| 24 |
+
|
| 25 |
+
void check_arguments(
|
| 26 |
+
const Tensor& weight,
|
| 27 |
+
const Tensor& indices,
|
| 28 |
+
const Tensor& offsets,
|
| 29 |
+
const int64_t mode,
|
| 30 |
+
const std::optional<Tensor>& per_sample_weights,
|
| 31 |
+
bool include_last_offset);
|
| 32 |
+
|
| 33 |
+
void make_bag_size_out(
|
| 34 |
+
Tensor& bag_size_out,
|
| 35 |
+
const Tensor& offsets,
|
| 36 |
+
const Tensor& indices,
|
| 37 |
+
const int64_t mode,
|
| 38 |
+
const bool include_last_offset,
|
| 39 |
+
const bool requires_grad);
|
| 40 |
+
|
| 41 |
+
void make_max_indices_out(
|
| 42 |
+
Tensor& max_indices_out,
|
| 43 |
+
const Tensor& weight,
|
| 44 |
+
const Tensor& indices,
|
| 45 |
+
const Tensor& offsets,
|
| 46 |
+
const Tensor& bag_size,
|
| 47 |
+
const int64_t mode,
|
| 48 |
+
bool include_last_offset);
|
| 49 |
+
|
| 50 |
+
void make_offset2bag_out(
|
| 51 |
+
Tensor& offset2bag,
|
| 52 |
+
Tensor& output,
|
| 53 |
+
const Tensor& weight,
|
| 54 |
+
const Tensor& indices,
|
| 55 |
+
const Tensor& offsets,
|
| 56 |
+
const int64_t mode,
|
| 57 |
+
const std::optional<Tensor>& per_sample_weights,
|
| 58 |
+
const int64_t padding_idx = -1);
|
| 59 |
+
|
| 60 |
+
#ifdef USE_FBGEMM
|
| 61 |
+
|
| 62 |
+
template<bool has_weight, typename TIndex, typename TData>
|
| 63 |
+
struct _CallbackAndBlockSize {
|
| 64 |
+
using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
|
| 65 |
+
|
| 66 |
+
int64_t blockSize = -1;
|
| 67 |
+
TCallback callback = nullptr;
|
| 68 |
+
|
| 69 |
+
static TCallback generateCallback(int64_t block_size) {
|
| 70 |
+
return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
|
| 71 |
+
block_size,
|
| 72 |
+
has_weight,
|
| 73 |
+
/* normalize_by_lengths */false,
|
| 74 |
+
/* prefetch */16,
|
| 75 |
+
/* is_weight_positional */false,
|
| 76 |
+
/* use_offsets */true);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
_CallbackAndBlockSize() = default;
|
| 80 |
+
|
| 81 |
+
explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
|
| 82 |
+
: blockSize(maybe_block_size.value_or(-1))
|
| 83 |
+
, callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
|
| 84 |
+
{}
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
template<typename... StorageMixins>
|
| 88 |
+
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
|
| 89 |
+
|
| 90 |
+
_EmbeddingBagKernelCacheImpl() = default;
|
| 91 |
+
// use each of the mixins to store corresponding kernel and block size
|
| 92 |
+
explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
|
| 93 |
+
: StorageMixins(maybe_block_size)...
|
| 94 |
+
{}
|
| 95 |
+
|
| 96 |
+
// this method is thread safe (call sites may call from different threads)
|
| 97 |
+
template<bool has_weight, typename TIndex, typename TData>
|
| 98 |
+
typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
|
| 99 |
+
getCallback(int64_t block_size) const {
|
| 100 |
+
// if the cache doesn't store the kernel for the incoming block size
|
| 101 |
+
// (so it is different from the one stored in corresponding mixin)
|
| 102 |
+
// regenerate the kernel (not writing it into the cache so we avoid locks)
|
| 103 |
+
if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
|
| 104 |
+
return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
|
| 105 |
+
}
|
| 106 |
+
// else retrieve the cached kernel from the corresponding mixin
|
| 107 |
+
return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
|
| 108 |
+
}
|
| 109 |
+
};
|
| 110 |
+
|
| 111 |
+
// instantiate the cache with the list of storage mixins
|
| 112 |
+
// for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
|
| 113 |
+
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
|
| 114 |
+
_CallbackAndBlockSize<true, int32_t, float>,
|
| 115 |
+
_CallbackAndBlockSize<false, int32_t, float>,
|
| 116 |
+
_CallbackAndBlockSize<true, int64_t, float>,
|
| 117 |
+
_CallbackAndBlockSize<false, int64_t, float>,
|
| 118 |
+
_CallbackAndBlockSize<true, int32_t, unsigned short>,
|
| 119 |
+
_CallbackAndBlockSize<false, int32_t, unsigned short>,
|
| 120 |
+
_CallbackAndBlockSize<true, int64_t, unsigned short>,
|
| 121 |
+
_CallbackAndBlockSize<false, int64_t, unsigned short>>;
|
| 122 |
+
#else
|
| 123 |
+
struct _EmbeddingBagKernelCache {
|
| 124 |
+
explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
|
| 125 |
+
};
|
| 126 |
+
#endif
|
| 127 |
+
|
| 128 |
+
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
|
| 129 |
+
Tensor& bag_size, Tensor* max_indices,
|
| 130 |
+
const Tensor &weight, const Tensor &indices,
|
| 131 |
+
const Tensor &offsets, const int64_t mode = 0,
|
| 132 |
+
const std::optional<Tensor>& per_sample_weights = std::nullopt,
|
| 133 |
+
bool include_last_offset = false,
|
| 134 |
+
int64_t padding_idx = -1,
|
| 135 |
+
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
|
| 136 |
+
|
| 137 |
+
void _embedding_bag_cpu_out(
|
| 138 |
+
at::Tensor& output,
|
| 139 |
+
at::Tensor& offset2bag,
|
| 140 |
+
at::Tensor& bag_size,
|
| 141 |
+
at::Tensor* p_max_indices,
|
| 142 |
+
const at::Tensor& weight,
|
| 143 |
+
const at::Tensor& indices,
|
| 144 |
+
const at::Tensor& offsets,
|
| 145 |
+
const bool scale_grad_by_freq,
|
| 146 |
+
const int64_t mode,
|
| 147 |
+
const bool sparse,
|
| 148 |
+
const std::optional<at::Tensor>& per_sample_weights,
|
| 149 |
+
const bool include_last_offset,
|
| 150 |
+
const std::optional<int64_t>& padding_idx,
|
| 151 |
+
_EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
|
| 152 |
+
|
| 153 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ForeachUtils.h
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Device.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/ScalarType.h>
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/native/utils/ParamsHash.h>
|
| 8 |
+
#include <c10/util/Exception.h>
|
| 9 |
+
#include <c10/util/irange.h>
|
| 10 |
+
|
| 11 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 12 |
+
#include <ATen/NativeFunctions.h>
|
| 13 |
+
#else
|
| 14 |
+
#include <ATen/ops/result_type_native.h>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include <unordered_map>
|
| 18 |
+
#include <vector>
|
| 19 |
+
|
| 20 |
+
namespace at::native {
|
| 21 |
+
namespace {
|
| 22 |
+
// Check if tensor list has either a boolean tensor or a integer tensor
|
| 23 |
+
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
|
| 24 |
+
return std::any_of(
|
| 25 |
+
tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
|
| 26 |
+
return at::isIntegralType(t.scalar_type(), includeBool);
|
| 27 |
+
});
|
| 28 |
+
}
|
| 29 |
+
// check if tensor list has bool tensors
|
| 30 |
+
inline bool has_bool_tensor(TensorList tensors) {
|
| 31 |
+
return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
|
| 32 |
+
return t.scalar_type() == ScalarType::Bool;
|
| 33 |
+
});
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// Check foreach API restrictions
|
| 37 |
+
// - Tensor lists must be non-empty.
|
| 38 |
+
// - All TensorLists and ScalarLists must have the same number of elements.
|
| 39 |
+
// - Corresponding tensors must have the same size.
|
| 40 |
+
inline void check_foreach_api_restrictions(TensorList tensors) {
|
| 41 |
+
TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
inline void check_foreach_api_restrictions(
|
| 45 |
+
TensorList tensors,
|
| 46 |
+
ArrayRef<Scalar> scalars) {
|
| 47 |
+
check_foreach_api_restrictions(tensors);
|
| 48 |
+
TORCH_CHECK(
|
| 49 |
+
tensors.size() == scalars.size(),
|
| 50 |
+
"Tensor list must have same number of elements as scalar list.");
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
inline void check_foreach_api_restrictions(
|
| 54 |
+
TensorList tensors1,
|
| 55 |
+
TensorList tensors2) {
|
| 56 |
+
TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
|
| 57 |
+
TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
|
| 58 |
+
TORCH_CHECK(
|
| 59 |
+
tensors1.size() == tensors2.size(),
|
| 60 |
+
"Tensor lists must have the same number of tensors, got ",
|
| 61 |
+
tensors1.size(),
|
| 62 |
+
" and ",
|
| 63 |
+
tensors2.size());
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
inline void check_foreach_api_restrictions(
|
| 67 |
+
TensorList tensors1,
|
| 68 |
+
TensorList tensors2,
|
| 69 |
+
TensorList tensors3) {
|
| 70 |
+
TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
|
| 71 |
+
TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
|
| 72 |
+
TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
|
| 73 |
+
TORCH_CHECK(
|
| 74 |
+
tensors1.size() == tensors2.size(),
|
| 75 |
+
"Tensor lists must have the same number of tensors, got ",
|
| 76 |
+
tensors1.size(),
|
| 77 |
+
" and ",
|
| 78 |
+
tensors2.size());
|
| 79 |
+
TORCH_CHECK(
|
| 80 |
+
tensors1.size() == tensors3.size(),
|
| 81 |
+
"Tensor lists must have the same number of tensors, got ",
|
| 82 |
+
tensors1.size(),
|
| 83 |
+
" and ",
|
| 84 |
+
tensors3.size());
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
inline void check_foreach_api_restrictions(
|
| 88 |
+
TensorList tensors1,
|
| 89 |
+
TensorList tensors2,
|
| 90 |
+
TensorList tensors3,
|
| 91 |
+
ArrayRef<Scalar> scalars) {
|
| 92 |
+
check_foreach_api_restrictions(tensors1, tensors2, tensors3);
|
| 93 |
+
TORCH_CHECK(
|
| 94 |
+
tensors1.size() == scalars.size(),
|
| 95 |
+
"Tensor list must have same number of elements as scalar list, got ",
|
| 96 |
+
tensors1.size(),
|
| 97 |
+
" and ",
|
| 98 |
+
scalars.size());
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// Helper function called in check_fast_path_restrictions to check whether all
|
| 102 |
+
// corresponding tensors (aligning in index across the tensorLists) share the
|
| 103 |
+
// same device and dtype.
|
| 104 |
+
inline bool _check_tensors_share_device_and_dtype(
|
| 105 |
+
ArrayRef<TensorList> tensorLists,
|
| 106 |
+
const bool skip_dtype_check = false) {
|
| 107 |
+
const auto expected_dtype = tensorLists[0][0].dtype();
|
| 108 |
+
const auto expected_device = tensorLists[0][0].device();
|
| 109 |
+
|
| 110 |
+
auto is_tensor_okay = [&](const Tensor& tensor) {
|
| 111 |
+
return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
|
| 112 |
+
tensor.device() == expected_device && tensor.layout() == at::kStrided &&
|
| 113 |
+
tensor.is_non_overlapping_and_dense();
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
for (const auto& tensorList : tensorLists) {
|
| 117 |
+
for (const auto& tensor : tensorList) {
|
| 118 |
+
if (!is_tensor_okay(tensor)) {
|
| 119 |
+
return false;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
return true;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
// Helper function called in check_fast_path_restrictions to check if
|
| 128 |
+
// corresponding tensors in tensor lists have the same sizes and strides.
|
| 129 |
+
inline bool _check_tensors_share_sizes_and_strides(
|
| 130 |
+
ArrayRef<TensorList> tensorLists) {
|
| 131 |
+
auto is_diff_stride = [](const IntArrayRef& size,
|
| 132 |
+
const IntArrayRef& left_stride,
|
| 133 |
+
const IntArrayRef& right_stride) -> bool {
|
| 134 |
+
const size_t size_size = size.size();
|
| 135 |
+
for (const auto dim : c10::irange(size_size)) {
|
| 136 |
+
if (size[dim] == 1)
|
| 137 |
+
continue;
|
| 138 |
+
if (left_stride[dim] != right_stride[dim]) {
|
| 139 |
+
return true;
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
return false;
|
| 143 |
+
};
|
| 144 |
+
for (const auto i : c10::irange(1, tensorLists.size())) {
|
| 145 |
+
for (const auto j : c10::irange(tensorLists[0].size())) {
|
| 146 |
+
if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
|
| 147 |
+
is_diff_stride(
|
| 148 |
+
tensorLists[0][j].sizes(),
|
| 149 |
+
tensorLists[0][j].strides(),
|
| 150 |
+
tensorLists[i][j].strides())) {
|
| 151 |
+
return false;
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
return true;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// Helper function called in check_fast_path_restrictions to check whether
|
| 160 |
+
// all tensors type promote properly with the scalars in scalarList. This
|
| 161 |
+
// function assumes that _check_tensors_share_device_and_dtype has already been
|
| 162 |
+
// called so that all corresponding tensors in tensorLists have the same dtype.
|
| 163 |
+
// Then, it is sufficient to check the type promotion with just one tensorList.
|
| 164 |
+
inline bool _check_tensors_do_type_promotion_with_scalars(
|
| 165 |
+
TensorList tensorList,
|
| 166 |
+
ArrayRef<Scalar> scalarList = {},
|
| 167 |
+
bool does_op_promote_integer_inputs_to_float = false) {
|
| 168 |
+
for (const auto i : c10::irange(tensorList.size())) {
|
| 169 |
+
// For division, integer inputs will result in float.
|
| 170 |
+
if (does_op_promote_integer_inputs_to_float) {
|
| 171 |
+
if (at::isIntegralType(
|
| 172 |
+
tensorList[i].scalar_type(), /*includeBool*/ true)) {
|
| 173 |
+
return false;
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
if (!scalarList.empty()) {
|
| 177 |
+
const auto& scalar =
|
| 178 |
+
scalarList.size() == 1 ? scalarList[0] : scalarList[i];
|
| 179 |
+
const auto& tensor = tensorList[i];
|
| 180 |
+
// note(mkozuki): This check might be responsible for
|
| 181 |
+
// `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
|
| 182 |
+
if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
|
| 183 |
+
return false;
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
return true;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// To go via 'fast' path, several conditions must be satisfied
|
| 192 |
+
// - All tensors in all lists must have the same dtype.
|
| 193 |
+
// - All tensors must be on the same device
|
| 194 |
+
// - All tensors must have strided layout
|
| 195 |
+
// - All tensors must be non-overlapping and dense
|
| 196 |
+
// - Resulting tensor must have the same dtype as the input one
|
| 197 |
+
|
| 198 |
+
// [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
|
| 199 |
+
// ``does_op_promote_integer_inputs_to_float=true`` means that the result of
|
| 200 |
+
// the op will be float even if inputs are integer or boolean, which
|
| 201 |
+
// currently fast path does not support. In short, this flag, when
|
| 202 |
+
// turned on, gatekeeps the op from going down the fastpath.
|
| 203 |
+
|
| 204 |
+
// Please, make sure to call check_foreach_api_restrictions before calling this
|
| 205 |
+
// method. There is a set of preconditions that have to be satisfied.
|
| 206 |
+
inline bool check_fast_path_restrictions(
|
| 207 |
+
ArrayRef<TensorList> tensorLists,
|
| 208 |
+
ArrayRef<Scalar> scalarList = {},
|
| 209 |
+
bool does_op_promote_integer_inputs_to_float = false) {
|
| 210 |
+
return _check_tensors_share_device_and_dtype(tensorLists) &&
|
| 211 |
+
_check_tensors_share_sizes_and_strides(tensorLists) &&
|
| 212 |
+
_check_tensors_do_type_promotion_with_scalars(
|
| 213 |
+
tensorLists[0],
|
| 214 |
+
scalarList,
|
| 215 |
+
does_op_promote_integer_inputs_to_float);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
|
| 219 |
+
const Tensor& scalarList_,
|
| 220 |
+
int64_t expect_length) {
|
| 221 |
+
std::vector<c10::Scalar> scalarList;
|
| 222 |
+
TORCH_CHECK(
|
| 223 |
+
scalarList_.device() == c10::kCPU,
|
| 224 |
+
"Expected scalars to be on CPU, got ",
|
| 225 |
+
scalarList_.device(),
|
| 226 |
+
" instead.");
|
| 227 |
+
TORCH_CHECK(
|
| 228 |
+
scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
|
| 229 |
+
TORCH_CHECK(
|
| 230 |
+
scalarList_.dim() == 1,
|
| 231 |
+
"Expected packed scalar Tensor to be of dimension 1. Got ",
|
| 232 |
+
scalarList_.dim(),
|
| 233 |
+
" instead.");
|
| 234 |
+
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
|
| 235 |
+
kComplexHalf,
|
| 236 |
+
kHalf,
|
| 237 |
+
kBool,
|
| 238 |
+
kBFloat16,
|
| 239 |
+
scalarList_.scalar_type(),
|
| 240 |
+
"convert_tensor_to_scalar_list",
|
| 241 |
+
[&]() {
|
| 242 |
+
const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
|
| 243 |
+
TORCH_CHECK(
|
| 244 |
+
(expect_length == scalarList_.size(0)),
|
| 245 |
+
"Expected length of scalars to match input of length ",
|
| 246 |
+
expect_length,
|
| 247 |
+
" but got ",
|
| 248 |
+
scalarList_.size(0),
|
| 249 |
+
" instead.");
|
| 250 |
+
for (int64_t i = 0; i < scalarList_.size(0); i++) {
|
| 251 |
+
scalarList.emplace_back(scalar_data[i]);
|
| 252 |
+
}
|
| 253 |
+
});
|
| 254 |
+
return scalarList;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
|
| 258 |
+
inline bool can_use_fast_route(
|
| 259 |
+
ArrayRef<TensorList> tensorLists,
|
| 260 |
+
ArrayRef<Scalar> scalarList = {},
|
| 261 |
+
bool does_op_promote_integer_inputs_to_float = false) {
|
| 262 |
+
return check_fast_path_restrictions(
|
| 263 |
+
tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
|
| 267 |
+
inline bool can_use_fast_route(
|
| 268 |
+
TensorList tensors1,
|
| 269 |
+
TensorList tensors2,
|
| 270 |
+
bool does_op_promote_integer_inputs_to_float = false) {
|
| 271 |
+
return can_use_fast_route(
|
| 272 |
+
{tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
|
| 276 |
+
using IndicesT = std::vector<size_t>;
|
| 277 |
+
using nested_optional_tensorvec_t =
|
| 278 |
+
std::vector<std::vector<std::optional<at::Tensor>>>;
|
| 279 |
+
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
|
| 280 |
+
using FlatMap = std::unordered_map<
|
| 281 |
+
DeviceDtypeKey,
|
| 282 |
+
TensorsAndIndicesT,
|
| 283 |
+
ParamsHash<DeviceDtypeKey>>;
|
| 284 |
+
|
| 285 |
+
inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
|
| 286 |
+
const nested_optional_tensorvec_t& nested_tensorlist,
|
| 287 |
+
const bool with_indices) {
|
| 288 |
+
FlatMap grouped_tensors_with_indices;
|
| 289 |
+
|
| 290 |
+
TORCH_CHECK(!nested_tensorlist.empty());
|
| 291 |
+
TORCH_CHECK(!nested_tensorlist[0].empty());
|
| 292 |
+
const auto num_lists = nested_tensorlist.size();
|
| 293 |
+
const auto num_tensors = nested_tensorlist[0].size();
|
| 294 |
+
|
| 295 |
+
TORCH_CHECK(std::all_of(
|
| 296 |
+
nested_tensorlist.cbegin(),
|
| 297 |
+
nested_tensorlist.cend(),
|
| 298 |
+
[&](const auto& tensorlist) -> bool {
|
| 299 |
+
// note(crcrpar): Allow empty tensorlists following
|
| 300 |
+
// ref:
|
| 301 |
+
// https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
|
| 302 |
+
return tensorlist.size() == num_tensors || tensorlist.size() == 0;
|
| 303 |
+
}));
|
| 304 |
+
|
| 305 |
+
for (const auto& tensor_index : c10::irange(num_tensors)) {
|
| 306 |
+
const auto key = [&]() -> DeviceDtypeKey {
|
| 307 |
+
const auto t = nested_tensorlist[0][tensor_index];
|
| 308 |
+
TORCH_CHECK(
|
| 309 |
+
t.has_value(),
|
| 310 |
+
"Tensors of the first list of nested Tensor lists are supposed to be defined but ",
|
| 311 |
+
"the ",
|
| 312 |
+
tensor_index,
|
| 313 |
+
"-th Tensor is not.");
|
| 314 |
+
return {t->device(), t->scalar_type()};
|
| 315 |
+
}();
|
| 316 |
+
TORCH_CHECK(
|
| 317 |
+
std::all_of(
|
| 318 |
+
nested_tensorlist.cbegin(),
|
| 319 |
+
nested_tensorlist.cend(),
|
| 320 |
+
[&](const auto& tensorlist) -> bool {
|
| 321 |
+
if (tensorlist.size() == 0) {
|
| 322 |
+
return true;
|
| 323 |
+
}
|
| 324 |
+
const auto& tensor = tensorlist[tensor_index];
|
| 325 |
+
// note(crcrpar): Currently the scope of this function is
|
| 326 |
+
// optimizers so there could be `state_steps` and other scalars
|
| 327 |
+
// whose elements are float tensors no matter what the parameter's
|
| 328 |
+
// dtype is.
|
| 329 |
+
if (!tensor.has_value()) {
|
| 330 |
+
return true;
|
| 331 |
+
} else {
|
| 332 |
+
const auto s = tensor->scalar_type();
|
| 333 |
+
const auto d = tensor->device();
|
| 334 |
+
// Note: `step` or `state_step` is float32 by default.
|
| 335 |
+
if (key.first == d) {
|
| 336 |
+
return key.second == s || s == at::ScalarType::Float ||
|
| 337 |
+
s == at::ScalarType::Double;
|
| 338 |
+
} else if (d.is_cpu()) {
|
| 339 |
+
// note(crcrpar): There are some test cases (e.g.
|
| 340 |
+
// TestOptim::test_adam) where state_steps are on CPU and the
|
| 341 |
+
// others are on CUDA. Currently a state_step Tensor has the
|
| 342 |
+
// dtype of float.
|
| 343 |
+
return s == at::ScalarType::Float ||
|
| 344 |
+
s == at::ScalarType::Double;
|
| 345 |
+
} else {
|
| 346 |
+
return false;
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
}),
|
| 350 |
+
"Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
|
| 351 |
+
if (!grouped_tensors_with_indices.count(key)) {
|
| 352 |
+
grouped_tensors_with_indices.insert(
|
| 353 |
+
{key,
|
| 354 |
+
TensorsAndIndicesT{
|
| 355 |
+
[&]() -> nested_optional_tensorvec_t {
|
| 356 |
+
nested_optional_tensorvec_t nested_tensorvec;
|
| 357 |
+
nested_tensorvec.reserve(num_lists);
|
| 358 |
+
for (const auto& i : c10::irange(num_lists)) {
|
| 359 |
+
std::vector<std::optional<at::Tensor>> tensors;
|
| 360 |
+
if (!nested_tensorlist[i].empty()) {
|
| 361 |
+
// NB: num_tensors is the max possible length for any of
|
| 362 |
+
// the inner lists of tensor references. Reserving the max
|
| 363 |
+
// trades memory for perf. This should not have significant
|
| 364 |
+
// impact.
|
| 365 |
+
tensors.reserve(num_tensors);
|
| 366 |
+
}
|
| 367 |
+
nested_tensorvec.emplace_back(tensors);
|
| 368 |
+
}
|
| 369 |
+
return nested_tensorvec;
|
| 370 |
+
}(),
|
| 371 |
+
[&]() -> IndicesT {
|
| 372 |
+
if (!with_indices) {
|
| 373 |
+
return {};
|
| 374 |
+
} else {
|
| 375 |
+
IndicesT indices;
|
| 376 |
+
indices.reserve(num_tensors);
|
| 377 |
+
return indices;
|
| 378 |
+
}
|
| 379 |
+
}()}});
|
| 380 |
+
}
|
| 381 |
+
for (const auto& list_index : c10::irange(num_lists)) {
|
| 382 |
+
if (!nested_tensorlist[list_index].empty()) {
|
| 383 |
+
grouped_tensors_with_indices[key].first[list_index].emplace_back(
|
| 384 |
+
nested_tensorlist[list_index][tensor_index]);
|
| 385 |
+
}
|
| 386 |
+
}
|
| 387 |
+
if (with_indices) {
|
| 388 |
+
grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
return grouped_tensors_with_indices;
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
} // namespace
|
| 396 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdagrad.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
|
| 4 |
+
namespace at::native {
|
| 5 |
+
|
| 6 |
+
using fused_adagrad_fn = void (*)(
|
| 7 |
+
const at::Tensor& param,
|
| 8 |
+
const at::Tensor& grad,
|
| 9 |
+
const at::Tensor& state_sum,
|
| 10 |
+
const at::Tensor& state_step,
|
| 11 |
+
const double lr,
|
| 12 |
+
const double lr_decay,
|
| 13 |
+
const double weight_decay,
|
| 14 |
+
const double eps,
|
| 15 |
+
const bool maximize,
|
| 16 |
+
const float* grad_scale_ptr);
|
| 17 |
+
|
| 18 |
+
DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
|
| 19 |
+
|
| 20 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedAdam.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
|
| 4 |
+
namespace at::native {
|
| 5 |
+
|
| 6 |
+
enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
|
| 7 |
+
|
| 8 |
+
using fused_adam_fn = void (*)(
|
| 9 |
+
const at::Tensor& param,
|
| 10 |
+
const at::Tensor& grad,
|
| 11 |
+
const at::Tensor& exp_avg,
|
| 12 |
+
const at::Tensor& exp_avg_sq,
|
| 13 |
+
const at::Tensor& max_exp_avg_sq,
|
| 14 |
+
const at::Tensor& state_step,
|
| 15 |
+
const double lr,
|
| 16 |
+
const double beta1,
|
| 17 |
+
const double beta2,
|
| 18 |
+
const double weight_decay,
|
| 19 |
+
const double eps,
|
| 20 |
+
const bool amsgrad,
|
| 21 |
+
const bool maximize,
|
| 22 |
+
const float* grad_scale_ptr,
|
| 23 |
+
const ADAM_MODE);
|
| 24 |
+
|
| 25 |
+
DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
|
| 26 |
+
|
| 27 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/FusedSGD.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
|
| 4 |
+
namespace at::native {
|
| 5 |
+
|
| 6 |
+
using fused_sgd_fn = void (*)(
|
| 7 |
+
const at::Tensor& param,
|
| 8 |
+
const at::Tensor& grad,
|
| 9 |
+
const at::Tensor& momentum_buffer,
|
| 10 |
+
const double weight_decay,
|
| 11 |
+
const double momentum,
|
| 12 |
+
const double lr,
|
| 13 |
+
const double dampening,
|
| 14 |
+
const bool nesterov,
|
| 15 |
+
const bool maximize,
|
| 16 |
+
const bool is_first_step,
|
| 17 |
+
const float* grad_scale_ptr);
|
| 18 |
+
|
| 19 |
+
DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub);
|
| 20 |
+
|
| 21 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSamplerUtils.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// See NOTE: [Tensor vs. TensorBase]
|
| 4 |
+
// https://github.com/pytorch/pytorch/pull/66979
|
| 5 |
+
#include <ATen/core/TensorBase.h>
|
| 6 |
+
#include <ATen/native/TensorProperties.h>
|
| 7 |
+
#include <ATen/native/CanUse32BitIndexMath.h>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
namespace detail {
|
| 12 |
+
|
| 13 |
+
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
|
| 14 |
+
enum class GridSamplerPadding {Zeros, Border, Reflection};
|
| 15 |
+
|
| 16 |
+
} // namespace detail
|
| 17 |
+
|
| 18 |
+
using detail::GridSamplerInterpolation;
|
| 19 |
+
using detail::GridSamplerPadding;
|
| 20 |
+
|
| 21 |
+
// See NOTE [ grid_sampler Native Functions ].
|
| 22 |
+
inline void check_grid_sampler_common(
|
| 23 |
+
const TensorBase& input,
|
| 24 |
+
const TensorBase& grid
|
| 25 |
+
) {
|
| 26 |
+
auto input_opt = input.options();
|
| 27 |
+
auto grid_opt = grid.options();
|
| 28 |
+
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
input.defined(),
|
| 31 |
+
"grid_sampler(): expected input to not be undefined");
|
| 32 |
+
TORCH_CHECK(
|
| 33 |
+
grid.defined(),
|
| 34 |
+
"grid_sampler(): expected grid to not be undefined");
|
| 35 |
+
TORCH_CHECK(
|
| 36 |
+
input_opt.device() == grid_opt.device(),
|
| 37 |
+
"grid_sampler(): expected input and grid to be on same device, but input "
|
| 38 |
+
"is on ", input_opt.device(), " and grid is on ", grid_opt.device());
|
| 39 |
+
TORCH_CHECK(
|
| 40 |
+
input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
|
| 41 |
+
"grid_sampler(): expected input and grid to have torch.strided layout, but "
|
| 42 |
+
"input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
|
| 43 |
+
TORCH_CHECK(
|
| 44 |
+
input.size(0) == grid.size(0),
|
| 45 |
+
"grid_sampler(): expected grid and input to have same batch size, but got "
|
| 46 |
+
"input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
|
| 47 |
+
TORCH_CHECK(
|
| 48 |
+
grid.size(-1) == input.dim() - 2,
|
| 49 |
+
"grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
|
| 50 |
+
"dimension, but got grid with sizes ", grid.sizes());
|
| 51 |
+
|
| 52 |
+
for (const auto i : c10::irange(2, input.dim())) {
|
| 53 |
+
TORCH_CHECK(input.size(i) > 0,
|
| 54 |
+
"grid_sampler(): expected input to have non-empty spatial dimensions, "
|
| 55 |
+
"but input has sizes ", input.sizes(), " with dimension ", i, " being "
|
| 56 |
+
"empty");
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// See NOTE [ grid_sampler Native Functions ].
|
| 61 |
+
inline void check_grid_sampler_2d(
|
| 62 |
+
const TensorBase& input,
|
| 63 |
+
const TensorBase& grid
|
| 64 |
+
) {
|
| 65 |
+
TORCH_CHECK(
|
| 66 |
+
input.dim() == 4 && input.dim() == grid.dim(),
|
| 67 |
+
"grid_sampler(): expected 4D input and grid with same number of "
|
| 68 |
+
"dimensions, but got input with sizes ", input.sizes(),
|
| 69 |
+
" and grid with sizes ", grid.sizes());
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// See NOTE [ grid_sampler Native Functions ].
|
| 73 |
+
inline void check_grid_sampler_3d(
|
| 74 |
+
const TensorBase& input,
|
| 75 |
+
const TensorBase& grid,
|
| 76 |
+
int64_t interpolation_mode
|
| 77 |
+
) {
|
| 78 |
+
TORCH_CHECK(
|
| 79 |
+
input.dim() == 5 && input.dim() == grid.dim(),
|
| 80 |
+
"grid_sampler(): expected 5D input and grid with same number of "
|
| 81 |
+
"dimensions, but got input with sizes ", input.sizes(),
|
| 82 |
+
" and grid with sizes ", grid.sizes());
|
| 83 |
+
TORCH_CHECK(
|
| 84 |
+
!(input.dim() == 5 &&
|
| 85 |
+
static_cast<GridSamplerInterpolation>(interpolation_mode) ==
|
| 86 |
+
GridSamplerInterpolation::Bicubic),
|
| 87 |
+
"grid_sampler(): bicubic interpolation only supports 4D input");
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// See NOTE [ grid_sampler Native Functions ].
|
| 91 |
+
// cudnn does not support inputs larger than 1024.
|
| 92 |
+
inline bool cond_cudnn_grid_sampler(
|
| 93 |
+
const TensorBase& input,
|
| 94 |
+
const TensorBase& grid
|
| 95 |
+
) {
|
| 96 |
+
return (
|
| 97 |
+
at::native::cudnn_is_acceptable(input) &&
|
| 98 |
+
at::native::cudnn_is_acceptable(grid) &&
|
| 99 |
+
at::native::canUse32BitIndexMath(input) &&
|
| 100 |
+
at::native::canUse32BitIndexMath(grid) &&
|
| 101 |
+
input.dim() == 4 &&
|
| 102 |
+
input.sym_size(1) <= 1024);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Histogram.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

// Kernel signatures for the histogramdd family of ops, registered per-backend
// through the DispatchStub mechanism below. Parameter meanings are defined by
// the kernel implementations; presumably (self, optional weight, density
// flag, hist output, bin-edge tensors) — confirm against the CPU/CUDA kernels.
using histogramdd_fn = void(*)(const Tensor&, const std::optional<Tensor>&, bool, Tensor&, const TensorList&);
// Variant with one extra trailing bool (semantics not visible from this header).
using histogramdd_linear_fn = void(*)(const Tensor&, const std::optional<Tensor>&, bool, Tensor&, const TensorList&, bool);
// Fills `leftmost_edges` / `rightmost_edges` (one entry per dimension) for
// `input`; `N` is supplied by the caller.
using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector<double> &leftmost_edges, std::vector<double> &rightmost_edges);

DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub);
DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub);
DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub);

} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexKernel.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>

// Forward declarations only — this header deliberately avoids pulling in the
// full Tensor/TensorIterator definitions.
namespace at {
class Tensor;
class TensorBase;
struct TensorIterator;
struct TensorIteratorBase;
}

namespace c10 {
class Scalar;
}

namespace at::native {

// Kernel signatures for the indexing family of ops (index, index_put,
// index_fill, masked_* , flip, put/take). Each is registered per-backend via
// the DECLARE_DISPATCH stubs below.
using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
// `accumulate` distinguishes plain assignment from accumulation (index_put_).
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
using flip_fn = void(*)(TensorIterator &, const bool);
using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);

DECLARE_DISPATCH(index_fn, index_stub);
DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
DECLARE_DISPATCH(put_fn, put_stub);
DECLARE_DISPATCH(take_fn, take_stub);
DECLARE_DISPATCH(flip_fn, flip_stub);
DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
// Serial and parallel variants share the masked_select_fn signature.
DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);

} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/IndexingUtils.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/ExpandUtils.h>
|
| 3 |
+
#include <ATen/native/CanUse32BitIndexMath.h>
|
| 4 |
+
#include <ATen/native/TensorIterator.h>
|
| 5 |
+
#include <ATen/core/IListRef.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
// Raises IndexError reporting that mask `mask` (the maskIdx-th entry in the
// index list) does not match dimension `idx` of the indexed tensor `self`.
[[noreturn]]
static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
  TORCH_CHECK_INDEX(false, "The shape of the mask ", mask.sizes(), " at index ", maskIdx,
  " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx);
}


static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef indices) {
  // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
  std::vector<Tensor> result;
  for (const auto& index_opt : indices) {
    if (!index_opt.has_value()) {
      // Missing index: keep an undefined placeholder Tensor so positions in
      // `result` stay aligned with dimensions of `self`.
      result.emplace_back();
    } else {
      const auto& index = *index_opt;
      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
        if (index.scalar_type() == kByte) {
          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
          " please use a dtype torch.bool instead.");
        }
        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
        // corresponding dimensions in self
        for (const auto j : c10::irange(index.dim())) {
          // result.size() counts index tensors emitted so far, which equals
          // the number of `self` dimensions already consumed.
          int64_t srcIdx = static_cast<int64_t>(result.size() + j);
          if (index.size(j) != self.size(srcIdx)) {
            invalid_mask(self, srcIdx, index, j);
          }
        }
        // Replace with nonzeros
        auto nonzero = index.nonzero();
        for (const auto j : c10::irange(index.dim())) {
          // One long-index tensor per mask dimension (column j of nonzero).
          result.emplace_back(nonzero.select(1, j));
        }
      } else {
        // Already an integer-typed index tensor; pass it through unchanged.
        result.emplace_back(index);
      }
    }
  }
  return result;
}
|
| 50 |
+
|
| 51 |
+
// Validates that every defined index tensor has an accepted index dtype.
// With allow_int=true, kInt is additionally accepted next to kLong/kByte/kBool.
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
  for (const auto& tensor : indices) {
    // Skip missing (nullopt) and undefined entries — only real index
    // tensors are type-checked.
    if (!tensor.has_value() || !tensor->defined()) {
      continue;
    }
    const auto scalarType = tensor->scalar_type();
    const bool base_ok =
        scalarType == kLong || scalarType == kByte || scalarType == kBool;
    if (allow_int) {
      TORCH_CHECK_INDEX(base_ok || scalarType == kInt,
          "tensors used as indices must be long, int, byte or bool tensors");
    } else {
      TORCH_CHECK_INDEX(base_ok,
          "tensors used as indices must be long, byte or bool tensors");
    }
  }
}
|
| 67 |
+
|
| 68 |
+
// Wraps each Tensor in `list` in an engaged std::optional.
inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
  torch::List<std::optional<Tensor>> result;
  result.reserve(list.size());
  for (const Tensor& a : list) {
    result.push_back(a);
  }
  return result;
}

// Same conversion from IValues: Tensor-valued IValues become engaged
// optionals, any other IValue becomes a disengaged optional.
inline torch::List<std::optional<Tensor>> toListOfOptionalTensors(ArrayRef<IValue> list) {
  torch::List<std::optional<Tensor>> result;
  result.reserve(list.size());
  for (const IValue& a : list) {
    result.push_back(a.isTensor() ? std::optional<Tensor>(a.toTensor()) : std::optional<Tensor>());
  }
  return result;
}
|
| 85 |
+
|
| 86 |
+
static C10_UNUSED bool hasContiguousSubspace(TensorList tl) {
  // true if all the non-null tensors are adjacent
  auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
  auto isNull = [](const Tensor & tensor){ return !tensor.defined(); };
  // `start` is the first defined tensor; `stop.base()` is one past the last
  // defined tensor. If an undefined tensor occurs inside [start, stop.base()),
  // the defined tensors do not form one contiguous run.
  auto start = std::find_if(tl.begin(), tl.end(), isDefined);
  auto stop = std::find_if(tl.rbegin(), tl.rend(), isDefined);
  auto it = std::find_if(start, stop.base(), isNull);
  return it == stop.base();
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
// Transposes the tensor and indices together so that all the non-null indices
|
| 98 |
+
// index the first k dimensions of the tensor. Returns the transposed tensor
|
| 99 |
+
// and the reordered indices. For example:
|
| 100 |
+
// transposeToFront(tensor, {nullptr, a, nullptr, b})
|
| 101 |
+
// returns
|
| 102 |
+
// tensor.permute([1, 3, 0, 2]), {a, b, nullptr, nullptr}
|
| 103 |
+
static C10_UNUSED std::tuple<Tensor, std::vector<Tensor>>
|
| 104 |
+
transposeToFront(const Tensor& self, TensorList indices) {
|
| 105 |
+
std::vector<int64_t> dims;
|
| 106 |
+
std::vector<Tensor> transposedIndices;
|
| 107 |
+
dims.reserve(self.dim());
|
| 108 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 109 |
+
if (indices[i].defined()) {
|
| 110 |
+
dims.push_back(i);
|
| 111 |
+
transposedIndices.emplace_back(indices[i]);
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
for (const auto i : c10::irange(self.dim())) {
|
| 115 |
+
if (!indices[i].defined()) {
|
| 116 |
+
dims.push_back(i);
|
| 117 |
+
transposedIndices.emplace_back();
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
return std::make_tuple(self.permute(dims), std::move(transposedIndices));
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// Same reordering as transposeToFront, but additionally returns the inverse
// permutation (invPerm) that maps original dimension positions back from the
// permuted layout.
inline std::tuple<Tensor, std::vector<Tensor>, std::vector<int64_t>>
transposeToFrontAndInvPerm(const Tensor& self, TensorList indices) {
  std::vector<int64_t> dims;
  std::vector<int64_t> invPerm;
  std::vector<Tensor> transposedIndices;
  dims.reserve(self.dim());
  invPerm.resize(self.dim());
  // Dimensions carrying a defined index come first...
  for (const auto i : c10::irange(self.dim())) {
    if (indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back(indices[i]);
    }
  }
  // ...followed by the remaining dimensions (undefined placeholders).
  for (const auto i : c10::irange(self.dim())) {
    if (!indices[i].defined()) {
      dims.push_back(i);
      transposedIndices.emplace_back();
    }
  }
  // invPerm[dims[i]] = i inverts the permutation stored in `dims`.
  for (const auto i : c10::irange(self.dim())) {
    invPerm[dims[i]] = i;
  }
  return std::make_tuple(self.permute(dims), std::move(transposedIndices), std::move(invPerm));
}
|
| 147 |
+
|
| 148 |
+
// Precomputed state for advanced (tensor) indexing. The constructor is
// defined elsewhere; from this header only the field layout is visible.
struct AdvancedIndex {
  AdvancedIndex(const Tensor& src, TensorList indices);

  Tensor src;                  // tensor being indexed
  std::vector<Tensor> indices; // index tensors, presumably reordered/expanded
  DimVector indexed_sizes;     // sizes of the indexed dims — confirm in ctor
  DimVector indexed_strides;   // strides of the indexed dims — confirm in ctor
  int64_t dims_before;         // count of dims preceding the indexed span
  int64_t dims_after;          // count of dims following the indexed span
};
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
} //namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Lerp.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/OpMathType.h>
|
| 5 |
+
#include <ATen/TensorIterator.h>
|
| 6 |
+
#include <c10/core/Scalar.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
// True when |weight| < 0.5; used to pick the better-conditioned lerp form.
template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
  return std::abs(weight) < scalar_t(0.5);
}
// Complex overload: compares |weight|^2 against 0.25 instead of |weight|
// against 0.5.
template <typename scalar_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
  // Avoid the sqrt in abs(weight)
  return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
}

// Linear interpolation between self_ and end_ by weight_, computed in the
// higher-precision opmath type of each argument.
template <typename scalar_t, typename weight_t>
C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
  using opmath_t = at::opmath_type<scalar_t>;
  using opmath_weight_t = at::opmath_type<weight_t>;

  opmath_t self = self_;
  opmath_t end = end_;
  opmath_weight_t weight = weight_;

  // Conditional for better numeric. This has been discussed in
  // https://github.com/pytorch/pytorch/pull/18871
  return is_lerp_weight_small(weight)
      ? self + weight * (end - self)
      : end - (end - self) * (opmath_t(1) - weight);
}

// Dispatch-stub signatures for the lerp kernels (scalar vs tensor weight).
using lerp_fn_scalar = void (*)(
    at::TensorIteratorBase& iter,
    const Scalar& weight);

using lerp_fn_tensor = void (*)(
    at::TensorIteratorBase& iter);

DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);
|
| 45 |
+
|
| 46 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebra.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/native/DispatchStub.h>

// Forward declarations keep this header lightweight.
namespace c10 {
class Scalar;
}

namespace at {
struct TensorIterator;
}

namespace at::native {

// Kernel signature for the addr (outer-product update) op:
// iterator plus the beta/alpha scaling factors.
using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha);
DECLARE_DISPATCH(addr_fn, addr_stub);
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/strides.h>
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
#include <ATen/ExpandUtils.h>
|
| 9 |
+
#include <ATen/TensorUtils.h>
|
| 10 |
+
#include <ATen/native/TensorIterator.h>
|
| 11 |
+
#include <ATen/native/TransposeType.h>
|
| 12 |
+
#include <limits>
|
| 13 |
+
#include <type_traits>
|
| 14 |
+
#include <sstream>
|
| 15 |
+
#include <cstring>
|
| 16 |
+
#include <cctype>
|
| 17 |
+
|
| 18 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 19 |
+
#include <ATen/Functions.h>
|
| 20 |
+
#else
|
| 21 |
+
#include <ATen/ops/arange.h>
|
| 22 |
+
#include <ATen/ops/empty.h>
|
| 23 |
+
#include <ATen/ops/empty_like.h>
|
| 24 |
+
#include <ATen/ops/empty_strided.h>
|
| 25 |
+
#include <ATen/ops/zeros.h>
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
namespace at::native {
|
| 29 |
+
|
| 30 |
+
// Returns `tensor` borrowed when it carries no conjugate bit; otherwise an
// owned tensor with the conjugation materialized.
inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
  if (tensor.is_conj()) {
    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
  } else {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  }
}

// Contiguous strides for `sizes`; with f_contig=true the innermost two
// dimensions are rewritten so each matrix is column-major (Fortran)
// contiguous while the batch dims stay C-contiguous.
inline DimVector batched_matrix_contiguous_strides(
    const IntArrayRef sizes,
    const bool f_contig = false) {
  // f_contig chooses between the strides of a batch of Fortran (F-contiguous)
  // and C-contiguous matrices
  auto strides = c10::contiguous_strides(sizes);
  auto dim = strides.size();

  if (f_contig && dim >= 2) {
    // Fix the strides of the last two dimensions, so that we return
    // C-contiguous batches of F-contiguous matrices.
    // max(..., 1) keeps the stride valid for zero-sized dimensions.
    strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
    strides[dim - 2] = 1;
  }
  return strides;
}
|
| 54 |
+
|
| 55 |
+
/*
 * Clones a Tensor so that the following conditions hold:
 * If we think of a Tensor of having size (B, M, N), where B is any number
 * of batch dimensions, then:
 * - Each (M, N) matrix is in column major form
 * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
 *   Then when laid out in memory, the M by N matrix starting at
 *   P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
 *   matrix starting at Q.data_ptr()[B * M' * N'].
 */
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  // If src is already in batched column major format, then
  // this will be efficient (no reordering of the data will occur)
  // because the first transpose will make the tensor contiguous,
  // and cloning a contiguous tensor is fast.
  auto result = src.mT().clone(at::MemoryFormat::Contiguous);
  // Undo the logical transpose so sizes match src while memory stays
  // column-major per matrix.
  result.transpose_(-2, -1);
  return result;
}

/*
 * contig chooses between C-contig (true) and F-contig (false)
 */
// Returns `borrow` borrowed when `cond` holds; otherwise an owned copy of
// `clone` in the requested contiguity.
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
                                                      : cloneBatchedColumnMajor(clone));
}

/*
 * This method is designed to be a faster alternative to
 * `cloneBatchedColumnMajor` with some additional features,
 * namely:
 * 1. It uses `copy` instead of `clone` which could be much faster.
 * 2. `nrows` parameter used to create inputs with the number of rows larger
 *    than the original input, which is required for some LAPACK/MAGMA methods.
 * 3. `desired_batch_size` is used to create copies with the batch size
 *    which is either the original batch size of the input, or its larger
 *    broadcasted shape.
 */
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
    at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
  // nrows == -1 means "keep the source's row count".
  nrows = (nrows == -1) ? src.size(-2) : nrows;
  auto copy_sizes = desired_batch_sizes.has_value()
      ? desired_batch_sizes.value().vec()
      : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
  copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
  const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
  auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
  // Only the leading src.size(-2) rows are filled; any extra rows remain
  // uninitialized (empty_strided does not zero memory).
  copy.narrow(-2, 0, src.size(-2)).copy_(src);
  return copy;
}
|
| 107 |
+
|
| 108 |
+
/*
 * Given batches of matrices with arbitrary batch dim,
 * computes the number of batches.
 */
inline int64_t batchCount(const Tensor& batched_matrices) {
  // Product of all dimensions except the trailing two (the matrix dims);
  // 1 for a plain 2-D matrix.
  int64_t result = 1;
  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
    result *= batched_matrices.size(i);
  }
  return result;
}

// Computes the number of elements of a matrix in a batched matrix tensor
inline int64_t matrixStride(const Tensor& batched_matrices) {
  return batched_matrices.size(-1) * batched_matrices.size(-2);
}
|
| 124 |
+
|
| 125 |
+
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
  TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}
// Checks that `self` is a (batch of) square matrix/matrices; f_name and
// arg_name are only used to compose the error message.
inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
  checkIsMatrix(self, f_name, arg_name);
  TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
              f_name,
              ": ", arg_name, " must be batches of square matrices, "
              "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}

// Checks A is square and that the matrix shapes of A and B are compatible
// with AX = B (left=true) or XA = B (left=false).
inline void checkInputsSolver(const Tensor& A,
                              const Tensor& B,
                              const bool left,
                              const char* const f_name) {
  squareCheckInputs(A, f_name, "A");
  checkIsMatrix(B, f_name, "B");
  TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
              f_name, ": Incompatible shapes of A and B for the equation ",
              left ? "AX = B" : "XA = B",
              " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}
|
| 148 |
+
|
| 149 |
+
// True when t is C-contiguous (row-major) or its matrix transpose is
// (i.e. column-major matrices).
inline bool is_row_or_column_contiguous(const Tensor& t) {
  // This could be made more general, similar to how it's checked in matmul, which would allow to
  // ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
  // We choose to be conservative for simplicity
  return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}

// Maps (contiguity, conjugation) to the LAPACK-style transpose flag;
// contig && conj is considered invalid and asserts.
inline TransposeType to_transpose_type(const bool contig, const bool conj) {
  if (conj) {
    if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
    else { return TransposeType::ConjTranspose; }
  } else {
    if (contig) { return TransposeType::NoTranspose; }
    else { return TransposeType::Transpose; }
  }
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
// This function is designed to be used with linear algebra methods that minimize
|
| 168 |
+
// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
|
| 169 |
+
// or the L2 norm (`lstsq`).
|
| 170 |
+
// It is expected that `a` and `b` are contiguous tensors of column-major matrices
|
| 171 |
+
// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
|
| 172 |
+
// with the following additional properties:
|
| 173 |
+
//
|
| 174 |
+
// 1. a.dim() == b.dim()
|
| 175 |
+
// 2. a.shape[:-2] broadcasts over b.shape[:-2]
|
| 176 |
+
// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
|
| 177 |
+
//
|
| 178 |
+
// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
|
| 179 |
+
// is to be memory efficient, which means that if there exists an index i such that
|
| 180 |
+
// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
|
| 181 |
+
// then instead of materializing copies of `a` in the broadcasted shape, we keep
|
| 182 |
+
// a buffer copy of `a` along with flags that check whether specific batch dimension
|
| 183 |
+
// indices for `a` were already accessed. If they were, we copy the data from the buffer
|
| 184 |
+
// into `a`. The number of copies does not exceed
|
| 185 |
+
// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
|
| 186 |
+
// and this value is attained by tensors with non-empty batch dimensions.
|
| 187 |
+
//
|
| 188 |
+
// func_t `f` is a callable that is being supplied with
|
| 189 |
+
// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
|
| 190 |
+
// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
|
| 191 |
+
// and a_linear_batch_idx is an index in the 3d representation which corresponds to
|
| 192 |
+
// the memory a_working_ptr points to, in other words:
|
| 193 |
+
// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}.select(0, a_linear_batch_idx).data_ptr<scalar_t>();
|
| 194 |
+
// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
|
| 195 |
+
// its rank or singular values (see linalg_lstsq).
|
| 196 |
+
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
  // Batch shapes exclude the trailing two (matrix) dimensions.
  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);

  // Linear batch index tensors reshaped to the batch shapes; the TensorIterator
  // below broadcasts a's indices against b's, pairing each b-batch with the
  // a-batch it reads from.
  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);

  TensorIterator iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(b_linear_batch_idx)
    .add_input(a_linear_batch_idx)
    .build();

  auto m = a.size(-2);
  auto n = a.size(-1);
  auto a_3d = a.view({batchCount(a), m, n});
  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});

  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
  Tensor a_buffer, a_was_accessed, a_buffer_3d;
  // No-op by default; replaced below only when `a` broadcasts over `b`.
  std::function<void(int64_t)> check_if_copy_needed_for_a
    = [](int64_t /*a_curr_linear_batch_idx*/){};
  if (a_broadcasts_over_b) {
    // Keep a pristine copy of `a` (same strides) plus a per-batch "already
    // used" flag array, so a batch of `a` that is consumed more than once can
    // be restored before each reuse (see NOTE above: f may modify a in-place).
    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
      .copy_(a);
    a_was_accessed = at::zeros(batchCount(a), at::kBool);
    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
      auto* a_was_accessed_flag = a_was_accessed
        .select(0, a_curr_linear_batch_idx)
        .data_ptr<bool>();
      if (!(*a_was_accessed_flag)) {
        // First access: just mark; data is still the original.
        *a_was_accessed_flag = true;
      }
      else {
        // Re-access: restore this batch of `a` from the buffer.
        a_3d.select(0, a_curr_linear_batch_idx)
          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
      }
    };
  }

  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    // data[0]/data[1] point at the broadcast (b, a) linear batch indices.
    auto* b_batch_idx_ptr = data[0];
    auto* a_batch_idx_ptr = data[1];

    for (const auto elem C10_UNUSED : c10::irange(nelems)) {
      auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

      check_if_copy_needed_for_a(a_curr_linear_batch_idx);

      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);

      b_batch_idx_ptr += strides[0];
      a_batch_idx_ptr += strides[1];
    }
  };
  // Serial on purpose: f mutates `a` and the copy-on-reuse flags are not
  // synchronized.
  iter.serial_for_each(loop, {0, batchCount(b)});
}
|
| 262 |
+
|
| 263 |
+
// Returns the epsilon value for floating types except half
|
| 264 |
+
inline double _get_epsilon(const ScalarType& sc_type) {
|
| 265 |
+
switch (sc_type) {
|
| 266 |
+
case at::ScalarType::Float:
|
| 267 |
+
return static_cast<double>(std::numeric_limits<float>::epsilon());
|
| 268 |
+
case at::ScalarType::Double:
|
| 269 |
+
return std::numeric_limits<double>::epsilon();
|
| 270 |
+
default:
|
| 271 |
+
AT_ERROR("This function doesn't handle types other than float and double");
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// Validates input shapes and devices
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
//
// `self` is the right-hand side b, `A` is the coefficient matrix; `name` is
// the caller's function name, used only in error messages.
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
  // Both operands must live on the same device.
  TORCH_CHECK(self.device() == A.device(),
              "Expected b and A to be on the same device, but found b on ",
              self.device(), " and A on ", A.device(), " instead.");

  // Mixed-dtype solves are not supported.
  TORCH_CHECK(self.scalar_type() == A.scalar_type(),
              "Expected b and A to have the same dtype, but found b of type ",
              self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");

  // A must be (batches of) square matrices.
  TORCH_CHECK(A.size(-1) == A.size(-2),
              "A must be batches of square matrices, "
              "but they are ", A.size(-2), " by ", A.size(-1), " matrices");

  // The row count of b must match A's dimension (A is square, so size(-1)
  // equals size(-2) here).
  TORCH_CHECK(A.size(-1) == self.size(-2),
              "Incompatible matrix sizes for ", name, ": each A "
              "matrix is ", A.size(-1), " by ", A.size(-1),
              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
|
| 295 |
+
|
| 296 |
+
// Checks that `t` has a floating point or complex dtype; raises otherwise.
// When `allow_low_precision_dtypes` is false, additionally restricts the
// dtype to fp32/fp64/complex64/complex128 (rejecting half/bfloat16).
inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
  auto dtype = t.scalar_type();
  TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
              f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
  if (!allow_low_precision_dtypes) {
    TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
                f_name, ": Low precision dtypes not supported. Got ", dtype);
  }
}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
// Checks if all the Tensors in a TensorList are of the same dimensions
|
| 308 |
+
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
|
| 309 |
+
for (auto &t : tensors) {
|
| 310 |
+
TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
// Computes the broadcasted *shapes* for two batched matrices: the batch
// dimensions (all dims except the trailing two) are broadcast against each
// other, and each result keeps its own trailing (rows, cols).
// Returns (arg1_expand_size, arg2_expand_size).
inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
  // broadcast the batch dimensions of arg1 and arg2.
  IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
  IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);

  // Append each tensor's own matrix dimensions to the shared batch portion.
  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });

  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
  return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
}
|
| 327 |
+
|
| 328 |
+
// Tensor-returning overload: optionally validates the inputs (when `name` is
// non-null) and returns the two tensors expanded to their broadcasted batch
// shapes. Tensors are only expanded when their shape actually changes.
inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
  // If there's no name we assume we don't want to check the errors
  if (name != nullptr) {
    linearSolveCheckInputs(arg1, arg2, name);
  }

  auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);

  // Avoid a view creation when the shape is already the broadcasted one.
  auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
  auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
|
| 340 |
+
|
| 341 |
+
inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
|
| 342 |
+
IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
|
| 343 |
+
IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
|
| 344 |
+
auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
|
| 345 |
+
return broadcasted_batch_sizes;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
// Return a permutation with the given axes moved to the end.
|
| 349 |
+
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
|
| 350 |
+
const std::vector<int64_t> a = axes.vec();
|
| 351 |
+
const int64_t ndim = self.ndimension();
|
| 352 |
+
std::vector<int64_t> perm;
|
| 353 |
+
|
| 354 |
+
for (const auto i : c10::irange(ndim)) {
|
| 355 |
+
auto it = std::find(a.begin(), a.end(), i);
|
| 356 |
+
if (it == a.end()) {
|
| 357 |
+
perm.push_back(i);
|
| 358 |
+
}
|
| 359 |
+
}
|
| 360 |
+
for (auto i : a) {
|
| 361 |
+
perm.push_back(i);
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
TORCH_CHECK((int64_t)perm.size() == ndim,
|
| 365 |
+
"duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
|
| 366 |
+
|
| 367 |
+
return self.permute(perm);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
|
| 371 |
+
inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
|
| 372 |
+
bool compute_q;
|
| 373 |
+
bool reduced;
|
| 374 |
+
if (mode == "reduced") {
|
| 375 |
+
compute_q = true;
|
| 376 |
+
reduced = true;
|
| 377 |
+
} else if (mode == "complete") {
|
| 378 |
+
compute_q = true;
|
| 379 |
+
reduced = false;
|
| 380 |
+
} else if (mode == "r") {
|
| 381 |
+
compute_q = false;
|
| 382 |
+
reduced = true; // this is actually irrelevant in this mode
|
| 383 |
+
} else {
|
| 384 |
+
TORCH_CHECK(false, "qr received unrecognized mode '", mode,
|
| 385 |
+
"' but expected one of 'reduced' (default), 'r', or 'complete'");
|
| 386 |
+
}
|
| 387 |
+
return std::make_tuple(compute_q, reduced);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
// Returns (q_sizes, q_strides, n_columns_q) where n_columns_q is the number
// of columns Q will have: m for a complete decomposition of a tall matrix,
// min(m, n) otherwise.
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
    const Tensor& input,
    bool reduced) {
  int64_t m = input.size(-2), n = input.size(-1);
  int64_t n_columns_q;

  // We need to compute the required size of Q based on the `reduced` option
  DimVector q_sizes(input.sizes());
  if (!reduced && m > n) {
    // Complete mode on a tall matrix: Q is m x m.
    q_sizes[input.dim() - 1] = m;
    n_columns_q = m;
  } else {
    // Reduced mode (or wide/square input): Q is m x min(m, n).
    q_sizes[input.dim() - 1] = n;
    n_columns_q = std::min(m, n);
  }
  // LAPACK-style routines expect column-major (Fortran-contiguous) strides.
  auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
  return std::make_tuple(q_sizes, q_strides, n_columns_q);
}
|
| 409 |
+
|
| 410 |
+
inline bool svd_uses_cusolver(const Tensor& A) {
|
| 411 |
+
// if cusolver is available, it is used unconditionally
|
| 412 |
+
return A.is_cuda()
|
| 413 |
+
&& at::globalContext().hasCuSOLVER()
|
| 414 |
+
&& at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
// Function used instead of .to so that the original strides are retained
|
| 419 |
+
// .to doesn't retain strides and make the output tensor contiguous
|
| 420 |
+
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
|
| 421 |
+
auto strided_to = at::empty_strided(original_tensor.sizes(),
|
| 422 |
+
original_tensor.strides(),
|
| 423 |
+
options);
|
| 424 |
+
strided_to.copy_(original_tensor);
|
| 425 |
+
return strided_to;
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
// the two specified dimensions to the end of a tensor, without changing the order of
// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
// placed just to the left of it.
//
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
// be `vec(0, 2, 1, 3)`.
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
  // dim0/dim1 must be distinct, non-negative, and in range.
  TORCH_CHECK(
    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
    "duplicate or invalid dimensions");
  std::vector<int64_t> permutation(ndim);
  int64_t cur_permuted_dim = 0;
  // Keep every other dimension in its original relative order.
  for (const auto dim_ind : c10::irange(ndim)) {
    if ((dim_ind != dim0) && (dim_ind != dim1)) {
      permutation[cur_permuted_dim++] = dim_ind;
    }
  }
  // Append dim0 then dim1 at the end.
  permutation[cur_permuted_dim++] = dim0;
  permutation[cur_permuted_dim] = dim1;
  return permutation;
}
|
| 451 |
+
|
| 452 |
+
// Creates a dimension permutation array that can be given to `at::permute()`,
// which will reverse a given permutation.
// For each source position, the inverse maps the permuted position back to it,
// i.e. inverse[permutation[i]] == i.
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
  const int64_t ndim = static_cast<int64_t>(permutation.size());
  std::vector<int64_t> inverse(ndim);
  for (int64_t src = 0; src < ndim; ++src) {
    inverse[permutation[src]] = src;
  }
  return inverse;
}
|
| 464 |
+
|
| 465 |
+
// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
  const int64_t mn = std::min(m, n);
  const int64_t mx = std::max(m, n);
  // jobz == 'N': only singular values are computed; the workspace is a small
  // fixed multiple of min(m, n).
  if (jobz == 'N') {
#ifdef __APPLE__
    // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
    return 7 * mn;
#else
    // These setting is valid for on LAPACK 3.6+
    return 5 * mn;
#endif
  }
  // Singular vectors requested: the bound depends on how rectangular the
  // matrix is (LAPACK uses the mx > 10*mn threshold).
  if (mx > 10 * mn) {
    return 5 * mn * mn + 5 * mn;
  }
  return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}
|
| 484 |
+
|
| 485 |
+
// This function checks whether the uplo argument input is valid
|
| 486 |
+
// Allowed strings are "u", "U", "l", "L"
|
| 487 |
+
inline void checkUplo(const c10::string_view uplo) {
|
| 488 |
+
// To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
|
| 489 |
+
char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
|
| 490 |
+
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
|
| 491 |
+
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
// Raises unless `result` and `input` live on the same device.
// `fn_name`/`result_name` are used only to build the error message.
inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
  TORCH_CHECK(
      result.device() == input.device(),
      fn_name,
      ": Expected ", result_name, " and input tensors to be on the same device, but got ",
      result_name, " on ", result.device(), " and input on ", input.device());
}
|
| 501 |
+
|
| 502 |
+
// Check the dtype of result and input tensors (for _out variants).
// Most linear algebra functions have the same dtype for input and output
// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
// c10::canCast is used for checking the "safe copy" dtype requirements.
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
  bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
  TORCH_CHECK(
      can_cast,
      fn_name,
      ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
      result_name, " with dtype ", result.scalar_type());
}
|
| 515 |
+
|
| 516 |
+
// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
|
| 517 |
+
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
|
| 518 |
+
bool can_cast = c10::canCast(result_type, out_type);
|
| 519 |
+
TORCH_CHECK(
|
| 520 |
+
can_cast,
|
| 521 |
+
fn_name,
|
| 522 |
+
": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
|
| 523 |
+
out_name, " with dtype ", out_type);
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// Raises if the tolerance tensor `tol` has a complex dtype — tolerances must
// be real-valued. `f_name`/`tol_name` are used only in the error message.
inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
  TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
}
|
| 530 |
+
|
| 531 |
+
/*
  Two types of 'other' tensors are supported when solving
  a system of linear equations matmul(input, x) = other:
  * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
  * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
  The original torch.solve supported only the matrix case, while NumPy works for both cases.
  For the batched input we need to be able to distinguish them.
  Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
  This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
*/
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
  // All of input's dims except the last, i.e. input.shape[:-1].
  auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
  // Vector case: a plain 1-D rhs, or a batch of vectors whose shape matches
  // input's batch dims plus m.
  bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
  return vector_case;
}
|
| 546 |
+
|
| 547 |
+
/*
  Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor.
*/
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
  // Index arithmetic is done with int64 on CPU regardless of the data tensor.
  TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
  // arange gives each original element its linear index; broadcasting then
  // replicates those indices into the broadcast layout, and contiguous()
  // materializes them for flat data_ptr access.
  return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}
|
| 554 |
+
|
| 555 |
+
// Maps linear indices of a (materialized) broadcast tensor back to linear
// indices into the original, smaller tensor. When no broadcasting actually
// occurs, indices pass through unchanged and no index tensor is allocated.
class BroadcastLinearIndices {
 private:
  // Precomputed int64 index map; only populated when broadcasting occurs.
  Tensor linear_indices_;
  bool is_broadcasting_;

 public:
  BroadcastLinearIndices(
      int64_t numel,
      IntArrayRef original_shape,
      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
    // The assumption is that the broadcast_shape is a materialized broadcast
    // shape of the original_shape. We need to compute the linear indices
    // compatible with the original_shape to access the elements in the original
    // tensor corresponding to the broadcast tensor.
    if (is_broadcasting_) {
      linear_indices_ =
          get_linear_indices(numel, original_shape, broadcast_shape);
    }
  }
  // Returns the linear index into the original tensor corresponding to
  // `broadcast_linear_index` in the broadcast tensor.
  int64_t operator()(int64_t broadcast_linear_index) {
    return is_broadcasting_
        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
        : broadcast_linear_index;
  }
};
|
| 580 |
+
|
| 581 |
+
// Returns true if `input`'s memory layout can be passed directly to a
// column-major (Fortran-order) BLAS/LAPACK routine: unit stride along rows,
// a leading dimension of at least max(1, rows), and (for 3-D batches) a
// batch stride large enough to hold a full matrix per batch.
inline bool is_blas_compatible_column_major_order(const Tensor& input) {
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  if (ndim > 3) {
    // For >3 dims, fall back to checking transposed contiguity as a whole.
    return input.transpose(-2, -1).is_contiguous();
  }
  // Column-major leading dimension is the column stride.
  auto leading_dimension = input_strides[ndim - 1];
  auto rows = input_sizes[ndim - 2];
  bool batch_stride_compatible = true;
  if (ndim == 3) {
    auto cols = input_sizes[ndim - 1];
    // Batch stride must cover at least one full (leading_dimension x cols)
    // matrix so batches do not overlap.
    batch_stride_compatible =
        input_strides[ndim - 3] >= leading_dimension * cols;
  }
  return (input_strides[ndim - 2] == 1) &&
      (leading_dimension >= std::max<int64_t>(1, rows)) &&
      batch_stride_compatible;
}
|
| 601 |
+
|
| 602 |
+
// Returns true if `input`'s memory layout can be passed directly to a
// row-major (C-order) BLAS routine: unit stride along columns, a leading
// dimension of at least max(1, cols), and (for 3-D batches) a batch stride
// large enough to hold a full matrix per batch.
inline bool is_blas_compatible_row_major_order(const Tensor& input) {
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  if (ndim > 3) {
    // For >3 dims, plain contiguity check covers the whole layout.
    return input.is_contiguous();
  }
  // Row-major leading dimension is the row stride.
  auto leading_dimension = input_strides[ndim - 2];
  auto cols = input_sizes[ndim - 1];
  bool batch_stride_compatible = true;
  if (ndim == 3) {
    auto rows = input_sizes[ndim - 2];
    // Batch stride must cover at least one full (rows x leading_dimension)
    // matrix so batches do not overlap.
    batch_stride_compatible =
        input_strides[ndim - 3] >= leading_dimension * rows;
  }
  return (input_strides[ndim - 1] == 1) &&
      (leading_dimension >= std::max<int64_t>(1, cols)) &&
      batch_stride_compatible;
}
|
| 622 |
+
|
| 623 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitsFallback.h
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 3 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 4 |
+
#include <ATen/native/UnaryOps.h>
|
| 5 |
+
#include <ATen/native/Resize.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <torch/library.h>
|
| 8 |
+
|
| 9 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 10 |
+
#include <ATen/Functions.h>
|
| 11 |
+
#else
|
| 12 |
+
#include <ATen/ops/clone.h>
|
| 13 |
+
|
| 14 |
+
#include <utility>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
// This fallback should only be used for operations that are self inverse and have a corresponding tensor
// bit (internally implemented using DispatchKey) to maintain the state on tensor using tensor bit.
// Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit.
// Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called.

// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
struct MathOpFallback {
  MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
  // Subclasses report whether their math bit (conj/neg) is set on a tensor.
  virtual bool is_bit_set(const Tensor&) = 0;
  // Boxed fallback: materializes math-bit inputs on the stack, redispatches
  // the op past `key`, and copies results back into mutable inputs.
  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
    /*
      Situations to handle:
        1. Out-of-place operation.  Easy: materialize all inputs and
          call it a day.
        2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_().
          Materialize other inputs as in (1).
        3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
          Materialize other inputs as in (1).

        It is important to be able to tell if we READ from an argument and if we
        WRITE to an argument.  Conservative approach is to assume that we always
        READ from an argument, but in out= operations you can skip
        conjugating inputs on entry that never get used. In the current schema we
        can't easily tell if the operation is in in-place or out= operation.

        Note:
        1. Mutable tensorlists containing tensors whose math bit set to true are disallowed.
        2. Mutable tensors with math bit set to true are unconditionally cloned to ensure
           correct behavior in the case when the mutable tensor shares memory with non mutable arguments.

           If we were to in-place resolve the math bit for mutable inputs, then the non-mutable inputs sharing partial or full memory
           with these mutable inputs would read into wrong values in the following cases:
           1. Non mutable inputs have their math bit set to false.
           2. Math bit for mutable input(s) is resolved before the non mutable inputs (with bit set to true and sharing memory
              with one or more mutable arg(s)) are cloned.
           At the end, the final value of the mutable arguments from the stack are copied into the original input mutable tensor inputs.
    */
    const auto& arguments = op.schema().arguments();
    const auto num_arguments = arguments.size();
    // Arguments occupy the top `num_arguments` slots of the stack.
    const auto stack_start = stack->size() - num_arguments;

    std::optional<bool> is_write;
    for (const auto i : c10::irange(num_arguments)) {
      // Three possible states:
      // 1. alias_info has no value --> out-of-place operation
      // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
      // 3. alias_info does have a value, alias_info->is_write=False --> view operation
      const AliasInfo* alias_info = arguments[i].alias_info();
      if (alias_info != nullptr) {
        if (is_write.has_value()) {
          // All aliasing args must agree on read vs write; a mix is unsupported.
          TORCH_CHECK(*is_write == alias_info->isWrite(),
            "Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
            op_name, " fallback doesn't work for operators with a mix "
            "mutable and non-mutable inputs that alias with outputs, "
            "this must be implemented manually.  "
            "If you got this error on a core op, please report a bug to PyTorch.");
        } else {
          is_write = alias_info->isWrite();
        }
      }
    }

    if (is_write.has_value() && !*is_write) {
      // We assume that view operators automatically handle the math bit
      // correctly by propagating the dispatch key in key_set.
      // This is not necessarily always right, so you should test these cases.
      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
      return;
    }

    // Mutable inputs with math bit set to True and their clones
    std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
    for (const auto i : c10::irange(num_arguments)) {
      auto& ivalue = (*stack)[stack_start + i];
      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
        continue;
      }
      const auto& argument = arguments[i];
      bool mut_arg = false;
      if (argument.alias_info()) {
        // Was already tested by is_write loop above
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
        mut_arg = true;
      }
      if (ivalue.isTensor()) {
        if (!is_bit_set(ivalue.toTensor())) {
          continue;
        }
        // clone materializes the math bit (e.g. performs the conjugation).
        auto tensor = std::move(ivalue).toTensor();
        auto resolved_tensor = at::clone(tensor);
        if (mut_arg) {
          TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensors with ",
            op_name, "bit set to true.");
          mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
        }
        (*stack)[stack_start + i] = std::move(resolved_tensor);
      } else if (ivalue.isTensorList()) {
        auto tensors = std::move(ivalue).toTensorList();
        for(const auto j : c10::irange(tensors.size())) {
          const auto& tensor = tensors[j];
          if (!is_bit_set(tensor)) {
            continue;
          }
          TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ",
              op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
              op.schema().name());
          tensors[j] = at::clone(tensor);
        }
        (*stack)[stack_start + i] = std::move(tensors);
      }
    }

    // Run the op with the math-bit key masked out of the dispatch set.
    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);

    TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);

    for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) {
      auto& mutable_input = mut_tensors.first;
      auto& cloned_mutable_input = mut_tensors.second;
      auto& ivalue = (*stack)[stack_start];
      auto returned_output = std::move(ivalue).toTensor();

      // sanity check to ensure that the tensor in stack aliases the cloned_mutable_input
      TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));

      // necessary for out= arg
      at::native::resize_output(mutable_input, returned_output.sizes());

      mutable_input.copy_(returned_output);
      (*stack)[stack_start] = std::move(mutable_input);
    }
  }

  virtual ~MathOpFallback() = default;

  DispatchKey key;
  string op_name;
};
|
| 156 |
+
|
| 157 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/MaxPooling.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <ATen/native/Pool.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
// Validates all arguments of max_pool1d (input rank, per-argument lengths,
// value ranges) and that the resulting output width is positive. Raises via
// TORCH_CHECK on any violation.
inline void check_max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {

  TORCH_CHECK(
      self.dim() == 2 || self.dim() == 3,
      "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
  TORCH_CHECK(
      kernel_size.size() == 1,
      "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
      kernel_size.size());
  TORCH_CHECK(
      stride.empty() || stride.size() == 1,
      "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
      stride.size());
  TORCH_CHECK(
      padding.size() == 1,
      "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
      padding.size());
  TORCH_CHECK(
      dilation.size() == 1,
      "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
      dilation.size());

  // If stride=None then set it to kernel_size
  if (stride.empty()) {
    stride = kernel_size;
  }

  TORCH_CHECK(
      kernel_size[0] > 0,
      "max_pool1d() kernel_size must be greater than zero, but got ",
      kernel_size[0]);
  TORCH_CHECK(
      stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
  TORCH_CHECK(
      padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
  TORCH_CHECK(
      padding[0] <= kernel_size[0] / 2,
      "max_pool1d() padding should be at most half of kernel size, but got padding=",
      padding[0],
      " and kernel_size=",
      kernel_size[0]);
  TORCH_CHECK(
      dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);

  // Finally verify the pooling geometry actually yields a non-empty output.
  const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
  TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
}
|
| 63 |
+
|
| 64 |
+
// TODO(Heitor) Template by dimension
|
| 65 |
+
struct PoolingParams1D {
|
| 66 |
+
int64_t NB; // Number of batches
|
| 67 |
+
int64_t NC; // Number of channels
|
| 68 |
+
int64_t IW; // Input width
|
| 69 |
+
int64_t OW; // Output width
|
| 70 |
+
int64_t KW; // Kernel width
|
| 71 |
+
int64_t SJ; // Column stride
|
| 72 |
+
int64_t PJ; // Column padding
|
| 73 |
+
int64_t DJ; // Column dilation
|
| 74 |
+
|
| 75 |
+
// Return index of input element for the given kernel and output index
|
| 76 |
+
inline int64_t index(int64_t kj, int64_t oj) const {
|
| 77 |
+
return oj * SJ + kj * DJ - PJ;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// Return index of first output within bounds for this kernel index
|
| 81 |
+
inline int64_t valid_output_start(int64_t kj) const {
|
| 82 |
+
int64_t ij = index(kj, 0);;
|
| 83 |
+
return ij < 0 ? at::divup(-ij, SJ) : 0;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Return index one past last output within bounds for this kernel index
|
| 87 |
+
inline int64_t valid_output_end(int64_t kj) const {
|
| 88 |
+
int64_t ij = index(kj, OW - 1);
|
| 89 |
+
return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
|
| 94 |
+
|
| 95 |
+
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
|
| 96 |
+
|
| 97 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonSymbolicBC.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
// This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
|
| 8 |
+
// However, in certain cases (such as static runtime), we call the native versions of the ops directly.
|
| 9 |
+
// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
|
| 10 |
+
TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
|
| 11 |
+
TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
|
| 12 |
+
TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, std::optional<at::ScalarType> dtype=std::nullopt, std::optional<at::Layout> layout=std::nullopt, std::optional<at::Device> device=std::nullopt, std::optional<bool> pin_memory=std::nullopt, std::optional<bool> is_coalesced=std::nullopt);
|
| 13 |
+
TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const std::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
|
| 14 |
+
TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const std::optional<at::Tensor>& weight_opt, int64_t reduction, int64_t ignore_index);
|
| 15 |
+
// The below ops don't get a duplicated C++ implementation.
|
| 16 |
+
// They are backward ops, which make them very unlikely to be called directly
|
| 17 |
+
// by external code (at::native::trace_backward).
|
| 18 |
+
// They get their own declaration for BC purposes however.
|
| 19 |
+
TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
|
| 20 |
+
TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const std::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
|
| 21 |
+
TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
|
| 22 |
+
TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
|
| 23 |
+
TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
|
| 24 |
+
TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index);
|
| 25 |
+
TORCH_API std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim);
|
| 26 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/RangeFactories.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/native/DispatchStub.h>
|
| 2 |
+
#include <c10/core/Scalar.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
struct TensorIterator;
|
| 6 |
+
|
| 7 |
+
namespace native {
|
| 8 |
+
|
| 9 |
+
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub);
|
| 10 |
+
DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub);
|
| 11 |
+
|
| 12 |
+
}} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceAllOps.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
using reduce_all_fn = void (*)(Tensor & result, const Tensor & self);
|
| 12 |
+
using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self);
|
| 13 |
+
DECLARE_DISPATCH(reduce_all_fn, min_all_stub);
|
| 14 |
+
DECLARE_DISPATCH(reduce_all_fn, max_all_stub);
|
| 15 |
+
|
| 16 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReductionType.h
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Scalar.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
enum class ReductionType {MAX, MEAN, MIN, SUM, PROD};
|
| 8 |
+
|
| 9 |
+
// Map a textual reduction name onto the ReductionType enum.
// Accepts "max"/"amax", "mean", "min"/"amin", "sum" and "prod";
// anything else trips the TORCH_CHECK below.
inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
  if (reduce == "max" || reduce == "amax") {
    return ReductionType::MAX;
  }
  if (reduce == "mean") {
    return ReductionType::MEAN;
  }
  if (reduce == "min" || reduce == "amin") {
    return ReductionType::MIN;
  }
  if (reduce == "sum") {
    return ReductionType::SUM;
  }
  if (reduce == "prod") {
    return ReductionType::PROD;
  }
  TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
}
|
| 24 |
+
|
| 25 |
+
// used for `scatter_reduce`, old options for BC.
|
| 26 |
+
inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
|
| 27 |
+
if (use_new_options) {
|
| 28 |
+
return get_reduction_enum(reduce);
|
| 29 |
+
} else {
|
| 30 |
+
if (reduce == "add") {
|
| 31 |
+
return ReductionType::SUM;
|
| 32 |
+
} else if (reduce == "multiply") {
|
| 33 |
+
return ReductionType::PROD;
|
| 34 |
+
} else {
|
| 35 |
+
TORCH_CHECK(false, "reduce argument must be either add or multiply.")
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
} // at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Repeat.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/TensorOperators.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/Functions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/empty.h>
|
| 10 |
+
#include <ATen/ops/empty_like.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
template <
|
| 16 |
+
typename index_t,
|
| 17 |
+
void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
|
| 18 |
+
static inline Tensor repeat_interleave_common(
|
| 19 |
+
const Tensor& repeats,
|
| 20 |
+
std::optional<int64_t> output_size) {
|
| 21 |
+
TORCH_CHECK(
|
| 22 |
+
repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
|
| 23 |
+
TORCH_CHECK(
|
| 24 |
+
repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
|
| 25 |
+
"repeats has to be Long or Int tensor");
|
| 26 |
+
if (repeats.size(0) == 0) {
|
| 27 |
+
return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
| 28 |
+
}
|
| 29 |
+
Tensor repeats_ = repeats.contiguous();
|
| 30 |
+
Tensor cumsum = repeats.cumsum(0);
|
| 31 |
+
int64_t total = 0;
|
| 32 |
+
if (output_size.has_value()) {
|
| 33 |
+
total = output_size.value();
|
| 34 |
+
} else {
|
| 35 |
+
total = cumsum[-1].item<int64_t>();
|
| 36 |
+
TORCH_CHECK(
|
| 37 |
+
(repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
Tensor result = at::empty({total}, repeats.options());
|
| 41 |
+
const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
|
| 42 |
+
const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
|
| 43 |
+
index_t* result_ptr = result.data_ptr<index_t>();
|
| 44 |
+
compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
|
| 45 |
+
return result;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Resize.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/ResizeCommon.h>
|
| 5 |
+
#include <ATen/EmptyTensor.h>
|
| 6 |
+
#include <ATen/TensorUtils.h>
|
| 7 |
+
|
| 8 |
+
#include <c10/core/CPUAllocator.h>
|
| 9 |
+
|
| 10 |
+
#include <utility>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
// TODO: make all operations that resize given outputs use this function
|
| 16 |
+
// for consistency and maintainability.
|
| 17 |
+
// Some operations like `cat` might not be able to make the use of
|
| 18 |
+
// resize_output directly. For more details to understand how it works in `cat`,
|
| 19 |
+
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
|
| 20 |
+
// Resizes outputs
|
| 21 |
+
// Functions accepting output tensors, like with the "out" kwarg, should
|
| 22 |
+
// call this function to handle resizing their output tensor.
|
| 23 |
+
// Issues a warning if the output tensor has one or more elements and
|
| 24 |
+
// needs resizing
|
| 25 |
+
// NOTE: In the future the warning will become an error
|
| 26 |
+
// Returns a bool saying whether or not the resize actually happened or not
|
| 27 |
+
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
|
| 28 |
+
// WARNING: Do NOT call this directly. If you are resizing an output and want
|
| 29 |
+
// to support dynamic shapes call at::resize__symint and resize_output_check_symint.
|
| 30 |
+
// For more details, see: https://github.com/pytorch/pytorch/pull/111530/files#r1365845272
|
| 31 |
+
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
|
| 32 |
+
|
| 33 |
+
// Utility for resize_output
|
| 34 |
+
// Returns a bool saying resize should happen or not and
|
| 35 |
+
// raises a warning if resizing for one or more elements
|
| 36 |
+
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
|
| 37 |
+
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
|
| 38 |
+
|
| 39 |
+
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
|
| 40 |
+
TORCH_API void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes);
|
| 41 |
+
TORCH_API void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& size_bytes);
|
| 42 |
+
|
| 43 |
+
// Grow (never shrink) the storage backing `self` so it holds at least
// `new_size_bytes`, creating a fresh CPU storage when none exists yet.
inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in cuda/Resize.h)
  if (self->numel() == 0) {
    return;
  }

  const Storage& storage = self->unsafe_storage();
  if (!storage) {
    // No backing storage yet: allocate one of the requested size.
    self->set_storage_keep_dtype(c10::make_intrusive<StorageImpl>(
        StorageImpl::use_byte_size_t(),
        new_size_bytes,
        c10::GetCPUAllocator(),
        true));
    return;
  }
  if (new_size_bytes > storage.nbytes()) {
    resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
  }
}
|
| 65 |
+
|
| 66 |
+
TORCH_API TensorImpl* resize_impl_cpu_(
|
| 67 |
+
TensorImpl* self,
|
| 68 |
+
IntArrayRef size,
|
| 69 |
+
at::OptionalIntArrayRef stride,
|
| 70 |
+
bool resize_storage = true);
|
| 71 |
+
|
| 72 |
+
template <typename T>
|
| 73 |
+
T maybe_convert_symint(c10::SymInt) = delete;
|
| 74 |
+
|
| 75 |
+
template <>
|
| 76 |
+
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }
|
| 77 |
+
|
| 78 |
+
template <>
|
| 79 |
+
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); }
|
| 80 |
+
|
| 81 |
+
template <typename T>
|
| 82 |
+
inline void checkInBoundsForStorage(
|
| 83 |
+
ArrayRef<T> size,
|
| 84 |
+
ArrayRef<T> stride,
|
| 85 |
+
T storage_offset,
|
| 86 |
+
const caffe2::TypeMeta& data_type,
|
| 87 |
+
const Storage& new_storage) {
|
| 88 |
+
T storage_size_bytes =
|
| 89 |
+
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
|
| 90 |
+
T storage_offset_bytes = storage_offset * data_type.itemsize();
|
| 91 |
+
if (storage_size_bytes == 0) {
|
| 92 |
+
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
|
| 93 |
+
return;
|
| 94 |
+
}
|
| 95 |
+
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
|
| 96 |
+
TORCH_CHECK(
|
| 97 |
+
storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
|
| 98 |
+
"setStorage: sizes ",
|
| 99 |
+
size,
|
| 100 |
+
", strides ",
|
| 101 |
+
stride,
|
| 102 |
+
","
|
| 103 |
+
" storage offset ",
|
| 104 |
+
storage_offset,
|
| 105 |
+
", and itemsize ",
|
| 106 |
+
data_type.itemsize(),
|
| 107 |
+
" requiring a storage size of ",
|
| 108 |
+
storage_size_bytes + storage_offset_bytes,
|
| 109 |
+
" are out of bounds for storage of size ",
|
| 110 |
+
new_storage_size_bytes);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
template <typename T>
|
| 114 |
+
inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
|
| 115 |
+
ArrayRef<T> size, ArrayRef<T> stride) {
|
| 116 |
+
// FIXME: stride should be optional
|
| 117 |
+
if (stride.data()) {
|
| 118 |
+
TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
|
| 119 |
+
") and stride length (", stride.size(), ")");
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
#ifdef DEBUG
|
| 123 |
+
TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
|
| 124 |
+
#endif
|
| 125 |
+
|
| 126 |
+
// storage: note this can't be replaced with result.set_(storage) as the semantics of that
|
| 127 |
+
// function is to set the tensor size to be equal to the size of the storage.
|
| 128 |
+
if (!result.storage().is_alias_of(storage)) {
|
| 129 |
+
// Caffe2 might have tensors whose storages are null, but we
|
| 130 |
+
// don't allow it in PyTorch.
|
| 131 |
+
TORCH_INTERNAL_ASSERT(storage);
|
| 132 |
+
TORCH_INTERNAL_ASSERT(result.storage());
|
| 133 |
+
|
| 134 |
+
// We used to allow this, but this breaks device caching.
|
| 135 |
+
// Let's put an actual error message for this one.
|
| 136 |
+
TORCH_CHECK(result.storage().device() == storage.device(),
|
| 137 |
+
"Attempted to set the storage of a tensor on device \"", result.storage().device(),
|
| 138 |
+
"\" to a storage on different device \"", storage.device(),
|
| 139 |
+
"\". This is no longer allowed; the devices must match.");
|
| 140 |
+
result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
// storageOffset
|
| 144 |
+
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
/**
|
| 148 |
+
* Set self's sizes, strides, and storage_offset.
|
| 149 |
+
* (size, stride, storage_offset) must be in bounds for self's storage.
|
| 150 |
+
*/
|
| 151 |
+
template <typename T>
|
| 152 |
+
inline void setStrided(
|
| 153 |
+
const Tensor& self,
|
| 154 |
+
ArrayRef<T> size,
|
| 155 |
+
ArrayRef<T> stride,
|
| 156 |
+
T storage_offset) {
|
| 157 |
+
TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
|
| 158 |
+
for (const auto& val : stride) {
|
| 159 |
+
TORCH_CHECK(val >= 0,
|
| 160 |
+
"as_strided: Negative strides are not supported at the moment, "
|
| 161 |
+
"got strides: ", stride);
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
auto* self_ = self.unsafeGetTensorImpl();
|
| 165 |
+
checkInBoundsForStorage(
|
| 166 |
+
size, stride, storage_offset, self_->dtype(), self_->storage());
|
| 167 |
+
|
| 168 |
+
/* storage offset */
|
| 169 |
+
TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
|
| 170 |
+
self_->set_sizes_and_strides(size, stride, std::make_optional(storage_offset));
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ResizeCommon.h
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/TensorFactories.h>
|
| 5 |
+
#include <ATen/NamedTensorUtils.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 9 |
+
#include <ATen/NativeFunctions.h>
|
| 10 |
+
#else
|
| 11 |
+
#include <ATen/ops/empty.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
template <typename T>
|
| 17 |
+
inline T storage_size_for(ArrayRef<T> size, ArrayRef<T> stride) {
|
| 18 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(),
|
| 19 |
+
"storage_size_for(size, stride) requires that size and stride ",
|
| 20 |
+
"have the same size as a precondition.");
|
| 21 |
+
T storage_size = 1;
|
| 22 |
+
for (const auto dim : c10::irange(size.size())) {
|
| 23 |
+
if (size[dim] == 0) {
|
| 24 |
+
storage_size = 0;
|
| 25 |
+
break;
|
| 26 |
+
}
|
| 27 |
+
storage_size += (size[dim] - 1) * stride[dim];
|
| 28 |
+
}
|
| 29 |
+
return storage_size;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
inline const Tensor& resize_named_tensor_(
|
| 33 |
+
const Tensor& self,
|
| 34 |
+
IntArrayRef size,
|
| 35 |
+
std::optional<MemoryFormat> optional_memory_format) {
|
| 36 |
+
TORCH_INTERNAL_ASSERT(self.has_names());
|
| 37 |
+
TORCH_CHECK(
|
| 38 |
+
self.sizes() == size,
|
| 39 |
+
"Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
|
| 40 |
+
"Tensor",
|
| 41 |
+
self.names(),
|
| 42 |
+
" with size ",
|
| 43 |
+
self.sizes(),
|
| 44 |
+
" to ",
|
| 45 |
+
size,
|
| 46 |
+
"). This may be caused by passing a named tensor ",
|
| 47 |
+
"as an `out=` argument; please ensure that the sizes are the same. ");
|
| 48 |
+
TORCH_CHECK(
|
| 49 |
+
!optional_memory_format.has_value(),
|
| 50 |
+
"Unsupported memory format for named tensor resize ",
|
| 51 |
+
optional_memory_format.value());
|
| 52 |
+
return self;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
// For deterministic output, fill new elements that were added after a storage
|
| 56 |
+
// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage
|
| 57 |
+
// before the resize happened.
|
| 58 |
+
inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) {
|
| 59 |
+
const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage();
|
| 60 |
+
int64_t new_storage_nbytes = storage.nbytes();
|
| 61 |
+
int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize();
|
| 62 |
+
int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize();
|
| 63 |
+
if (new_storage_numel > old_storage_numel) {
|
| 64 |
+
at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device()));
|
| 65 |
+
tensor_view.set_(
|
| 66 |
+
storage,
|
| 67 |
+
/*storage_offset=*/old_storage_numel,
|
| 68 |
+
/*size=*/{new_storage_numel - old_storage_numel},
|
| 69 |
+
/*stride=*/{1});
|
| 70 |
+
at::native::fill_empty_deterministic_(tensor_view);
|
| 71 |
+
}
|
| 72 |
+
return tensor;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/ScatterGatherChecks.h
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <vector>
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <ATen/native/ReduceOpsUtils.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
namespace {
|
| 11 |
+
|
| 12 |
+
// checks whether index.dtype == int64
|
| 13 |
+
// and self.dtype == src.dtype if src is a Tensor
|
| 14 |
+
inline void scatter_gather_dtype_check(
|
| 15 |
+
const std::string& method_name,
|
| 16 |
+
const Tensor& self,
|
| 17 |
+
const Tensor& index,
|
| 18 |
+
const std::optional<Tensor>& src_opt = std::nullopt
|
| 19 |
+
) {
|
| 20 |
+
if (index.numel() != 0) {
|
| 21 |
+
TORCH_CHECK(
|
| 22 |
+
index.scalar_type() == at::ScalarType::Long,
|
| 23 |
+
method_name, "(): Expected dtype int64 for index"
|
| 24 |
+
);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
if (src_opt.has_value()) {
|
| 28 |
+
const auto& src = src_opt.value();
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
self.scalar_type() == src.scalar_type(),
|
| 31 |
+
method_name, "(): Expected self.dtype to be equal to src.dtype"
|
| 32 |
+
);
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// Used for `gather`-like methods
|
| 37 |
+
// Note: self means the input tensor here
|
| 38 |
+
// Test:
|
| 39 |
+
// 1. index.size(d) <= self.size(d) for all d != dim
|
| 40 |
+
// 2. index.dim() == self.dim()
|
| 41 |
+
inline void gather_shape_check(const Tensor& self, int64_t dim,
|
| 42 |
+
const Tensor& index
|
| 43 |
+
) {
|
| 44 |
+
auto self_dims = ensure_nonempty_dim(self.dim());
|
| 45 |
+
TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
|
| 46 |
+
"Index tensor must have the same number of dimensions as input tensor"
|
| 47 |
+
);
|
| 48 |
+
|
| 49 |
+
for (const auto i : c10::irange(self_dims)) {
|
| 50 |
+
if (i != dim) {
|
| 51 |
+
TORCH_CHECK(
|
| 52 |
+
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
|
| 53 |
+
"Size does not match at dimension ", i,
|
| 54 |
+
" expected index ", index.sizes(),
|
| 55 |
+
" to be smaller than self ", self.sizes(),
|
| 56 |
+
" apart from dimension ", dim
|
| 57 |
+
);
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// Used for `scatter` and `scatter_add`
|
| 63 |
+
// Tests:
|
| 64 |
+
// 1. index.size(d) <= self.size(d) for all d != dim
|
| 65 |
+
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
|
| 66 |
+
// 3. index.dim() == self.dim() == src.dim()
|
| 67 |
+
inline void scatter_shape_check(
|
| 68 |
+
const Tensor& self, int64_t dim, const Tensor& index,
|
| 69 |
+
const std::optional<Tensor>& src_opt = std::nullopt
|
| 70 |
+
) {
|
| 71 |
+
if (index.numel() == 0) return;
|
| 72 |
+
TORCH_CHECK(
|
| 73 |
+
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
|
| 74 |
+
"Index tensor must have the same number of dimensions as self tensor"
|
| 75 |
+
);
|
| 76 |
+
|
| 77 |
+
bool is_wrong_shape = false;
|
| 78 |
+
int64_t self_dims = ensure_nonempty_dim(self.dim());
|
| 79 |
+
|
| 80 |
+
// Check: index.size(d) <= self.size(d) for all d != dim
|
| 81 |
+
for (const auto d : c10::irange(self_dims)) {
|
| 82 |
+
int64_t index_d_size = ensure_nonempty_size(index, d);
|
| 83 |
+
if (d == dim) continue;
|
| 84 |
+
if (index_d_size > ensure_nonempty_size(self, d)) {
|
| 85 |
+
is_wrong_shape = true;
|
| 86 |
+
break;
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// Check: index.size(d) <= src.size(d) for all d if src is Tensor
|
| 91 |
+
if (!is_wrong_shape && src_opt.has_value()) {
|
| 92 |
+
const auto& src = src_opt.value();
|
| 93 |
+
for (const auto d : c10::irange(self_dims)) {
|
| 94 |
+
int64_t index_d_size = ensure_nonempty_size(index, d);
|
| 95 |
+
if (index_d_size > ensure_nonempty_size(src, d)) {
|
| 96 |
+
is_wrong_shape = true;
|
| 97 |
+
break;
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
if (src_opt.has_value()) {
|
| 103 |
+
const auto& src = src_opt.value();
|
| 104 |
+
|
| 105 |
+
TORCH_CHECK(
|
| 106 |
+
ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
|
| 107 |
+
"Index tensor must have the same number of dimensions as src tensor"
|
| 108 |
+
);
|
| 109 |
+
|
| 110 |
+
TORCH_CHECK(!is_wrong_shape,
|
| 111 |
+
"Expected index ", index.sizes(),
|
| 112 |
+
" to be smaller than self ", self.sizes(),
|
| 113 |
+
" apart from dimension ", dim,
|
| 114 |
+
" and to be smaller size than src ", src.sizes()
|
| 115 |
+
);
|
| 116 |
+
}
|
| 117 |
+
else {
|
| 118 |
+
TORCH_CHECK(!is_wrong_shape,
|
| 119 |
+
"Expected index ", index.sizes(),
|
| 120 |
+
" to be smaller than self ", self.sizes(),
|
| 121 |
+
" apart from dimension ", dim
|
| 122 |
+
);
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
} // anonymous namespace
|
| 127 |
+
|
| 128 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/SharedReduceOps.h
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// Please note that this file is
|
| 3 |
+
// used across both CPU and GPU.
|
| 4 |
+
|
| 5 |
+
#include <type_traits>
|
| 6 |
+
#include <complex>
|
| 7 |
+
#include <c10/macros/Macros.h>
|
| 8 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 9 |
+
#include <ATen/NumericUtils.h>
|
| 10 |
+
#if defined(__CUDACC__)
|
| 11 |
+
#include <ATen/cuda/DeviceUtils.cuh>
|
| 12 |
+
#include <ATen/native/cuda/DeviceSqrt.cuh>
|
| 13 |
+
#elif defined(__HIPCC__)
|
| 14 |
+
#include <ATen/hip/DeviceUtils.cuh>
|
| 15 |
+
#include <ATen/native/hip/DeviceSqrt.cuh>
|
| 16 |
+
#endif
|
| 17 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 18 |
+
#include <thrust/pair.h>
|
| 19 |
+
#else
|
| 20 |
+
#include <cmath>
|
| 21 |
+
#define device_sqrt std::sqrt
|
| 22 |
+
#endif
|
| 23 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 24 |
+
// max(a, b) with NaN propagation: if either operand is NaN, the result is NaN.
// Backs the MAX macro on CUDA/HIP builds.
template <typename scalar_t>
inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
#else
  // Only `b` needs an explicit NaN check: std::max(a, b) returns `a` when the
  // comparison is false, so a NaN `a` already propagates on its own.
  scalar_t max = at::_isnan(b) ? b : std::max(a, b);
#endif
  return max;
}
|
| 35 |
+
// min(a, b) with NaN propagation: if either operand is NaN, the result is NaN.
// Backs the MIN macro on CUDA/HIP builds.
template <typename scalar_t>
inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#if defined(__HIPCC__)
  // TODO: remove this special case for HIP when issue is fixed:
  //       https://github.com/ROCm-Developer-Tools/HIP/issues/2209
  scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
#else
  // Only `b` needs an explicit NaN check: std::min(a, b) returns `a` when the
  // comparison is false, so a NaN `a` already propagates on its own.
  scalar_t min = at::_isnan(b) ? b : std::min(a, b);
#endif
  return min;
}
|
| 46 |
+
#define MAX(X, Y) max_propagate_nan(X,Y)
|
| 47 |
+
#define MIN(X, Y) min_propagate_nan(X,Y)
|
| 48 |
+
#else
|
| 49 |
+
#include <ATen/native/cpu/zmath.h>
|
| 50 |
+
#define MAX(X, Y) max_impl(X,Y)
|
| 51 |
+
#define MIN(X, Y) min_impl(X,Y)
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
// ROCM hcc doesn't work well with using std:: in kernel functions
|
| 55 |
+
#if defined(__CUDA_ARCH__)
|
| 56 |
+
#include <c10/cuda/CUDAMathCompat.h>
|
| 57 |
+
#define compat_pow c10::cuda::compat::pow
|
| 58 |
+
#elif defined(__HIPCC__)
|
| 59 |
+
#include <c10/hip/HIPMathCompat.h>
|
| 60 |
+
#define compat_pow c10::hip::compat::pow
|
| 61 |
+
#else
|
| 62 |
+
#define compat_pow std::pow
|
| 63 |
+
#endif
|
| 64 |
+
|
| 65 |
+
namespace at { namespace native {
|
| 66 |
+
|
| 67 |
+
namespace detail {
|
| 68 |
+
|
| 69 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 70 |
+
template <typename T1, typename T2> using pair = thrust::pair<T1, T2>;
|
| 71 |
+
#else
|
| 72 |
+
template <typename T1, typename T2> using pair = std::pair<T1, T2>;
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
} // namespace detail
|
| 76 |
+
|
| 77 |
+
// Running state for Welford's online variance algorithm: the current mean,
// the sum of squared deviations from the mean (m2), and the element count
// held both as an integer (n, exact) and as a floating-point value (nf,
// usable where the integer type could overflow — see WelfordOps::combine).
template <typename scalar_t, typename index_t>
struct WelfordData {
  scalar_t mean;
  scalar_t m2;
  index_t n;
  scalar_t nf;

  // Identity element for the reduction: zero elements seen.
  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

  C10_HOST_DEVICE WelfordData(
      scalar_t mean,
      scalar_t m2,
      index_t n,
      scalar_t nf)
      : mean(mean), m2(m2), n(n), nf(nf) {}
};
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
// Reduction-op bundle implementing Welford's numerically stable online
// algorithm for variance / standard deviation. project() constructs a
// res_t from (variance-or-stddev, mean).
template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
  acc_scalar_t correction;  // degrees-of-freedom correction subtracted from the count
  bool take_sqrt;           // true: project() emits stddev; false: variance
 public:
  using acc_t = WelfordData<acc_scalar_t, index_t>;
  // Folds a single element into the running state (classic Welford update).
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    // We accumulate n in index_t to avoid cumulative rounding error, but still
    // need nf for use in combine where int32 may overflow.
    index_t new_n = acc.n + 1;
    acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
    acc_scalar_t delta = data - acc.mean;
    acc_scalar_t new_mean = acc.mean + delta / new_nf;
    acc_scalar_t new_delta = data - new_mean;
    return {
      new_mean,
      acc.m2 + delta * new_delta,
      new_n,
      new_nf,
    };
  }
  // Merges two partial Welford states (used by parallel / segmented
  // reductions). Empty partials are returned unchanged so the identity
  // accumulator never skews the merged mean.
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    if (a.nf == 0) {
      return b;
    }
    if (b.nf == 0) {
      return a;
    }
    acc_scalar_t delta = b.mean - a.mean;
    acc_scalar_t new_count = a.nf + b.nf;
    acc_scalar_t nb_over_n = b.nf / new_count;
    return {
      a.mean + delta * nb_over_n,
      a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
      // setting acc.n as -1 since acc.n might not be able to represent the count
      // correctly within its range, setting it to -1 to avoid confusion
      -1,
      new_count
    };
  }
  // Finalizes the state. When nf <= correction the divisor is 0 and the
  // division intentionally yields inf/nan — hence the ubsan suppression.
  inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
    const auto mean = static_cast<scalar_t>(acc.mean);
    const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
    const auto var = acc.m2 / divisor;
    res_t results(take_sqrt ? device_sqrt(var) : var, mean);
    return results;
  }

  // Index translation is a no-op: this reduction does not track positions.
  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Shuffles every field down the warp so lanes can tree-reduce via combine().
  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.mean, offset)
      , WARP_SHFL_DOWN(acc.m2, offset)
      , WARP_SHFL_DOWN(acc.n, offset)
      , WARP_SHFL_DOWN(acc.nf, offset)
    };
  }
#endif
  C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
    : correction(correction), take_sqrt(take_sqrt) {}
};
|
| 160 |
+
|
| 161 |
+
// Mean reduction: accumulates a plain sum and scales it by `factor` in
// project(). `factor` is supplied by the caller — presumably 1/N for N
// reduced elements; confirm at the call site.
template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
struct MeanOps {
  factor_t factor;

  inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
    return combine(a, static_cast<acc_t>(b));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final scaling: sum * factor.
  inline C10_DEVICE out_t project(acc_t a) const {
    return a * factor;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif

  MeanOps(factor_t factor): factor(factor) {
  }
};
|
| 190 |
+
|
| 191 |
+
// This accumulator template is used to calculate the minimum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
// MIN expands to the NaN-propagating device min on CUDA/HIP and to min_impl
// on CPU (see the macros at the top of this file).
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMinOps {

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MIN(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MIN(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 220 |
+
|
| 221 |
+
// This accumulator template is used to calculate the maximum absolute value of
// a set of numbers.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
// MAX expands to the NaN-propagating device max on CUDA/HIP and to max_impl
// on CPU (see the macros at the top of this file).
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct AbsMaxOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return MAX(acc, static_cast<acc_t>(std::abs(data)));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return MAX(a, b);
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 249 |
+
|
| 250 |
+
// This accumulator template is used to calculate the norm of the absolute value
// of a set of numbers, i.e. (sum_i |x_i|^p)^(1/p) with p = norm_.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOps {
  acc_t norm_;  // the order p of the norm

  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + compat_pow(static_cast<acc_t>(std::abs(data)), norm_);
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Final p-th root of the accumulated sum of |x|^p.
  inline C10_DEVICE out_t project(acc_t a) const {
    return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif

  NormOps(acc_t norm_): norm_(norm_) {
  }
};
|
| 283 |
+
|
| 284 |
+
// This accumulator template is used to calculate the order zero norm of the
// absolute value of a set of numbers: it simply counts non-zero elements.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormZeroOps {
  // Adds 1 for every non-zero element.
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 313 |
+
|
| 314 |
+
// This accumulator template is used to calculate the order one norm of the
// absolute value of a set of numbers: sum_i |x_i|.
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormOneOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    return acc + static_cast<acc_t>(std::abs(data));
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return a;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
// Tag type carrying the accumulator type, used only to drive overload
// selection in abs_if_complex below.
template<typename acc_t>
struct AbsSwitch {};

// Real input: passed through with a cast only; no absolute value is taken.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(data);
}

// std::complex input: reduced to its magnitude.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}

// c10::complex input: reduced to its magnitude.
template<typename scalar_t, typename acc_t>
inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t>) {
  return static_cast<acc_t>(std::abs(data));
}
|
| 361 |
+
|
| 362 |
+
// This accumulator template is used to calculate the order two norm of the
// absolute value of a set of numbers: sqrt(sum_i |x_i|^2).
// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
// value. These types differ for complex number input support.
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
struct NormTwoOps {
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
    // abs_if_complex maps complex inputs to their magnitude; real inputs are
    // only cast, since squaring below already discards the sign.
    acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
    return acc + data_ * data_;
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  inline C10_DEVICE out_t project(acc_t a) const {
    return device_sqrt(a);
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return WARP_SHFL_DOWN(acc, offset);
  }
#endif
};
|
| 391 |
+
|
| 392 |
+
// Sum reduction that treats NaN elements as zero (i.e. skips them).
template <typename acc_t, typename data_t>
struct NanSumOps {
  inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
    return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }

  // Narrows the accumulator back to the data type for output.
  inline C10_DEVICE data_t project(acc_t a) const {
    return data_t{a};
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
#endif
};
|
| 416 |
+
|
| 417 |
+
namespace detail {
|
| 418 |
+
|
| 419 |
+
// Comparator for min / argmin reductions. Returns true when (a, idx_a)
// should be kept over (b, idx_b): NaNs always win (so they propagate),
// and the lower index breaks ties.
template <typename scalar_t>
struct LessOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else min(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a < b);
  }
};
|
| 432 |
+
|
| 433 |
+
// Comparator for max / argmax reductions. Returns true when (a, idx_a)
// should be kept over (b, idx_b): NaNs always win (so they propagate),
// and the lower index breaks ties.
template <typename scalar_t>
struct GreaterOrNan {
  C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
    // If (a == b), then choose the one with lower idx, else max(a, b)
    if (at::_isnan(a)) {
      if (at::_isnan(b)) {
        return idx_a < idx_b;
      }
      return true;
    }
    return (a == b) ? idx_a < idx_b : (a > b);
  }
};
|
| 446 |
+
|
| 447 |
+
// Generic (value, index) reduction driven by a comparator (LessOrNan /
// GreaterOrNan above). The accumulator is a pair of the winning value and
// its index; project() returns the whole pair.
template <typename comp_t>
struct MinMaxReductionOps {
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  using index_t = int64_t;
  using arg_t = detail::pair<scalar_t, index_t>;

  static C10_DEVICE arg_t project(arg_t arg) {
    return arg;
  }

  // Keeps the current winner unless the comparator prefers the new element.
  static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
    return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  }

  static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
    return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  }

  // Shifts a chunk-local index to a global index once the chunk's base is known.
  static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
    return {a.first, a.second + base_idx};
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
    return arg_t(WARP_SHFL_DOWN(arg.first, offset),
                 WARP_SHFL_DOWN(arg.second, offset));
  }
#endif
};
|
| 476 |
+
|
| 477 |
+
// Same reduction as MinMaxReductionOps, but project() discards the value and
// returns only the winning index (argmin / argmax).
template <typename comp_t>
struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  using typename MinMaxReductionOps<comp_t>::scalar_t;
  using typename MinMaxReductionOps<comp_t>::index_t;
  using typename MinMaxReductionOps<comp_t>::arg_t;

  static C10_DEVICE index_t project(arg_t arg) {
    return arg.second;
  }
};
|
| 487 |
+
|
| 488 |
+
} // namespace detail
|
| 489 |
+
|
| 490 |
+
// argmax: index of the maximum (NaN-propagating, lowest index on ties).
template <typename scalar_t>
struct ArgMaxOps :
  public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
};
|
| 494 |
+
|
| 495 |
+
// argmin: index of the minimum (NaN-propagating, lowest index on ties).
template <typename scalar_t>
struct ArgMinOps :
  public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
};
|
| 499 |
+
|
| 500 |
+
// min with index: yields the (value, index) pair of the minimum.
template <typename scalar_t>
struct MinOps :
  public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
};
|
| 504 |
+
|
| 505 |
+
// max with index: yields the (value, index) pair of the maximum.
template <typename scalar_t>
struct MaxOps :
  public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
};
|
| 509 |
+
|
| 510 |
+
// Simultaneous min+max reduction in a single pass. The accumulator is a
// (min, max) pair; NaNs propagate into both slots via the _isnan checks
// in combine().
template <typename scalar_t, typename acc_scalar_t, typename index_t>
struct MinMaxOps {
  using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
    return combine(acc, {data, data});
  }

  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    // A NaN on the left wins both comparisons, so NaN is sticky.
    auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
    auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;

    return {min_val, max_val};
  }

  inline C10_DEVICE acc_t project(acc_t acc) const {
    return acc;
  }

  static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
    return acc;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {
      WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
    };
  }
#endif
};
|
| 540 |
+
|
| 541 |
+
}} // namespace at::native
|
| 542 |
+
|
| 543 |
+
#undef MAX
|
| 544 |
+
#undef MIN
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Sorting.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
class TensorBase;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
// Interpolation strategies for quantile computation, selecting how a quantile
// lying between two data points is resolved (mirrors torch.quantile's
// `interpolation=` argument).
enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
  LINEAR,
  LOWER,
  HIGHER,
  MIDPOINT,
  NEAREST
};
|
| 19 |
+
|
| 20 |
+
using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
|
| 21 |
+
using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);
|
| 22 |
+
|
| 23 |
+
DECLARE_DISPATCH(sort_fn, sort_stub);
|
| 24 |
+
DECLARE_DISPATCH(topk_fn, topk_stub);
|
| 25 |
+
|
| 26 |
+
void _fill_indices(const TensorBase &indices, int64_t dim);
|
| 27 |
+
|
| 28 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/SortingUtils.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/NumericUtils.h>
|
| 4 |
+
#include <ATen/native/Resize.h>
|
| 5 |
+
#include <c10/util/irange.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/Functions.h>
|
| 9 |
+
#else
|
| 10 |
+
#include <ATen/ops/empty.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
// ensure we get good values and indices for kthvalue, mode
// this will always be with the reducing dim as 1-d
//
// If `values` / `indices` are undefined they are freshly allocated; otherwise
// their dtype/device are validated and they are resized in place to
// self.sizes() with the reduced dimension set to 1.
inline void _reduction_with_indices_allocate_or_resize_output(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    // 0-dim inputs have no dim to shrink; otherwise the reduced dim becomes 1.
    result_sizes[dim] = 1;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    if (!keepdim && values.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      values.unsqueeze_(dim);
    }
    resize_output(values, result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    if (!keepdim && indices.dim() == self.dim() - 1) {
      // unsqueeze to preserve passed in noncontiguous tensor in resize
      indices.unsqueeze_(dim);
    }
    resize_output(indices, result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
|
| 55 |
+
|
| 56 |
+
// ensure we get good values and indices for topk
//
// Same contract as _reduction_with_indices_allocate_or_resize_output above,
// except the reduced dimension is sized `k` (top-k keeps k elements) and the
// outputs are resized with resize_() rather than resize_output().
inline void _allocate_or_resize_output_with_indices(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim_,
    int64_t k) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  auto result_sizes = self.sizes().vec();
  if (!result_sizes.empty()) {
    result_sizes[dim] = k;
  }
  if (values.defined()) {
    TORCH_CHECK(
        self.options().type_equal(values.options()),
        "output values must be of same type as input");
    values.resize_(result_sizes);
  } else {
    values = at::empty(result_sizes, self.options());
  }
  if (indices.defined()) {
    TORCH_CHECK(
        indices.dtype() == kLong, "output indices must be of scalar type Long");
    TORCH_CHECK(
        indices.device() == self.device(),
        "output indices must be on same device as input");
    indices.resize_(result_sizes);
  } else {
    indices = at::empty(result_sizes, self.options().dtype(kLong));
  }
}
|
| 87 |
+
|
| 88 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/SparseTensorUtils.h
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
#include <ATen/SparseTensorImpl.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/Functions.h>
|
| 9 |
+
#else
|
| 10 |
+
#include <ATen/ops/empty.h>
|
| 11 |
+
#include <ATen/ops/tensor.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
namespace at::sparse {
|
| 15 |
+
|
| 16 |
+
// Just for documentary purposes
|
| 17 |
+
using SparseTensor = Tensor;
|
| 18 |
+
using SparseType = Type;
|
| 19 |
+
|
| 20 |
+
// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
  // The static_cast below is only valid for a sparse layout, hence the assert.
  TORCH_INTERNAL_ASSERT(
      self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
  return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}
|
| 32 |
+
|
| 33 |
+
// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
// NOTE: the caller retains aliases of `indices`/`values`; mutating them
// afterwards mutates `self` as well.
inline void alias_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values) {
  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}
|
| 41 |
+
|
| 42 |
+
// Take indices and values and makes a (data) copy of them to put into the
// sparse indices/values. This used to be called THSTensor_(_set)
// `non_blocking` is forwarded to Tensor::to for the (possibly cross-device)
// copies of indices and values.
inline void copy_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values,
    bool non_blocking) {
  alias_into_sparse(
      self,
      indices.to(self._indices().options(), non_blocking, /*copy=*/true),
      values.to(self._values().options(), non_blocking, /*copy=*/true));
}
|
| 54 |
+
|
| 55 |
+
// TODO: put this into the public API
// True iff lhs and rhs are views of the exact same TensorImpl (identity,
// not value equality).
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}
|
| 59 |
+
|
| 60 |
+
// True iff both sparse tensors split their dimensions identically into
// sparse and dense parts.
inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
  return self.sparse_dim() == src.sparse_dim() &&
      self.dense_dim() == src.dense_dim();
}
|
| 64 |
+
|
| 65 |
+
// Give us a new values tensor, with the same dimensionality
// as 'values' but with a new number of non-zero elements.
// TODO: Expose this for real in ATen, some day?
// NB: Doesn't preserve data.
inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
  // Dim 0 of a sparse values tensor indexes the non-zero entries.
  std::vector<int64_t> size = values.sizes().vec();
  size[0] = nnz;
  return at::empty(size, values.options());
}
|
| 74 |
+
|
| 75 |
+
// NOTE [ Flatten Sparse Indices ]
|
| 76 |
+
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
|
| 77 |
+
// indices tensor. E.g.,
|
| 78 |
+
// input = [[2, 4, 0],
|
| 79 |
+
// [3, 1, 10]]
|
| 80 |
+
// full_size = [2, 12]
|
| 81 |
+
// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
|
| 82 |
+
//
|
| 83 |
+
// In other words, assuming that each `indices[i, :]` is a valid index to a
|
| 84 |
+
// tensor `t` of shape `full_size`. This returns the corresponding indices to
|
| 85 |
+
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
|
| 86 |
+
// If force_clone is true, the result will be forced to be a clone of self.
|
| 88 |
+
TORCH_API Tensor flatten_indices(
|
| 89 |
+
const Tensor& indices,
|
| 90 |
+
IntArrayRef full_size,
|
| 91 |
+
bool force_clone = false);
|
| 92 |
+
|
| 93 |
+
// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
|
| 94 |
+
// Sparse Indices ], except this one allows partial flatten: only flatten on
|
| 95 |
+
// specified dims. Note that the flatten indices might be uncoalesced if
|
| 96 |
+
// dims_to_flatten.size() < sparse_dim. Also if input indices is already
|
| 97 |
+
// coalesced, the flattened indices will also be sorted.
|
| 98 |
+
//
|
| 99 |
+
// args:
|
| 100 |
+
// indices: sparse tensor indices
|
| 101 |
+
// sizes: sparse tensor sizes
|
| 102 |
+
// dims_to_flatten: a list of dim index to flatten
|
| 103 |
+
//
|
| 104 |
+
// Ex1:
|
| 105 |
+
// indices = [[2, 4, 0],
|
| 106 |
+
// [3, 1, 3]]
|
| 107 |
+
// sizes = [2, 12]
|
| 108 |
+
// dims_to_flatten = [0, 1]
|
| 109 |
+
// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
|
| 110 |
+
//
|
| 111 |
+
// Ex2:
|
| 112 |
+
// dims_to_flatten = [1]
|
| 113 |
+
// new_indices = [ 3, 1, 3 ] # uncoalesced
|
| 114 |
+
TORCH_API Tensor flatten_indices_by_dims(
|
| 115 |
+
const Tensor& indices,
|
| 116 |
+
const IntArrayRef& sizes,
|
| 117 |
+
const IntArrayRef& dims_to_flatten);
|
| 118 |
+
|
| 119 |
+
// Find the CSR representation for a row `indices` from the COO format
|
| 120 |
+
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
|
| 121 |
+
|
| 122 |
+
TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
|
| 123 |
+
|
| 124 |
+
template <size_t static_shape_max_len>
|
| 125 |
+
class TensorGeometryHolder {
|
| 126 |
+
using geometry_holder_t = std::array<int64_t, static_shape_max_len>;
|
| 127 |
+
|
| 128 |
+
public:
|
| 129 |
+
explicit TensorGeometryHolder(
|
| 130 |
+
IntArrayRef sizes,
|
| 131 |
+
IntArrayRef strides,
|
| 132 |
+
TensorOptions options = {}) {
|
| 133 |
+
std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
|
| 134 |
+
std::copy(strides.begin(), strides.end(), t_strides.begin());
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
explicit TensorGeometryHolder(const Tensor& t)
|
| 138 |
+
: TensorGeometryHolder(t.sizes(), t.strides()) {}
|
| 139 |
+
|
| 140 |
+
auto operator*() const {
|
| 141 |
+
return std::make_tuple(t_sizes, t_strides);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
private:
|
| 145 |
+
geometry_holder_t t_sizes;
|
| 146 |
+
geometry_holder_t t_strides;
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
template <>
|
| 150 |
+
class TensorGeometryHolder<0> {
|
| 151 |
+
using geometry_holder_t = Tensor;
|
| 152 |
+
|
| 153 |
+
public:
|
| 154 |
+
explicit TensorGeometryHolder(
|
| 155 |
+
IntArrayRef sizes,
|
| 156 |
+
IntArrayRef strides,
|
| 157 |
+
TensorOptions options) {
|
| 158 |
+
const int64_t t_ndims = sizes.size();
|
| 159 |
+
const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
|
| 160 |
+
Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
|
| 161 |
+
t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
|
| 162 |
+
t_sizes_and_strides_cpu.select(0, 1).copy_(
|
| 163 |
+
at::tensor(strides, cpu_options));
|
| 164 |
+
const Tensor t_sizes_and_strides =
|
| 165 |
+
t_sizes_and_strides_cpu.to(options.device());
|
| 166 |
+
t_sizes = t_sizes_and_strides.select(0, 0);
|
| 167 |
+
t_strides = t_sizes_and_strides.select(0, 1);
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
explicit TensorGeometryHolder(const Tensor& t)
|
| 171 |
+
: TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}
|
| 172 |
+
|
| 173 |
+
auto operator*() const {
|
| 174 |
+
return std::make_tuple(
|
| 175 |
+
t_sizes.template data_ptr<int64_t>(),
|
| 176 |
+
t_strides.template data_ptr<int64_t>());
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
private:
|
| 180 |
+
geometry_holder_t t_sizes;
|
| 181 |
+
geometry_holder_t t_strides;
|
| 182 |
+
};
|
| 183 |
+
|
| 184 |
+
// Return all indices of a tensor with the given shape.
|
| 185 |
+
//
|
| 186 |
+
// full_coo_indices(shape) is equivalent to
|
| 187 |
+
// torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
|
| 188 |
+
TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
|
| 189 |
+
|
| 190 |
+
} // namespace at::sparse
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexing.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Indexing tensors by tensors
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/List.h>
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/native/DispatchStub.h>
|
| 8 |
+
#include <ATen/native/ReductionType.h>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
struct TensorIterator;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<std::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
|
| 17 |
+
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<std::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
|
| 18 |
+
using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
|
| 19 |
+
using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
|
| 20 |
+
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
|
| 21 |
+
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
|
| 22 |
+
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
|
| 23 |
+
const Tensor& src, const ReductionType& reduce);
|
| 24 |
+
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
|
| 25 |
+
const Scalar& value, const ReductionType& reduce);
|
| 26 |
+
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
|
| 27 |
+
const Tensor& src, const ReductionType& reduce);
|
| 28 |
+
|
| 29 |
+
DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
|
| 30 |
+
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
|
| 31 |
+
DECLARE_DISPATCH(gather_fn, gather_stub);
|
| 32 |
+
DECLARE_DISPATCH(scatter_fn, scatter_stub);
|
| 33 |
+
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
|
| 34 |
+
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
|
| 35 |
+
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
|
| 36 |
+
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
|
| 37 |
+
DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
|
| 38 |
+
|
| 39 |
+
TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<std::optional<at::Tensor>>& indices);
|
| 40 |
+
|
| 41 |
+
using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
|
| 42 |
+
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
|
| 43 |
+
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
|
| 44 |
+
|
| 45 |
+
DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
|
| 46 |
+
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
|
| 47 |
+
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
|
| 48 |
+
|
| 49 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorAdvancedIndexingUtils.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/native/IndexingUtils.h>
|
| 4 |
+
#include <ATen/native/TensorIterator.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
namespace {
|
| 8 |
+
#ifndef STRIP_ERROR_MESSAGES
|
| 9 |
+
inline std::string shapes_as_str(TensorList tensors) {
|
| 10 |
+
std::ostringstream os;
|
| 11 |
+
bool first = true;
|
| 12 |
+
for (auto& tensor : tensors) {
|
| 13 |
+
if (tensor.defined()) {
|
| 14 |
+
if (!first) {
|
| 15 |
+
os << ", ";
|
| 16 |
+
}
|
| 17 |
+
os << tensor.sizes();
|
| 18 |
+
first = false;
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
return os.str();
|
| 22 |
+
}
|
| 23 |
+
#endif
|
| 24 |
+
} // anonymous namespace
|
| 25 |
+
|
| 26 |
+
inline std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<std::optional<at::Tensor>>& indices,
|
| 27 |
+
const Tensor& value){
|
| 28 |
+
if (!(value.numel() ==1 && value.device().is_cpu())){
|
| 29 |
+
return std::make_tuple(false,Tensor());
|
| 30 |
+
}
|
| 31 |
+
int64_t num_ind = 0;
|
| 32 |
+
Tensor mask;
|
| 33 |
+
auto self_device = self.device();
|
| 34 |
+
for (const std::optional<Tensor>& i: indices) {
|
| 35 |
+
if (!i.has_value() || !(*i).defined()){
|
| 36 |
+
num_ind++;
|
| 37 |
+
} else {
|
| 38 |
+
const Tensor &index = *i;
|
| 39 |
+
if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
|
| 40 |
+
index.device() != self_device || mask.defined()){
|
| 41 |
+
return std::make_tuple(false, Tensor());
|
| 42 |
+
} else {
|
| 43 |
+
mask = index;
|
| 44 |
+
for (const auto j : c10::irange(index.dim())) {
|
| 45 |
+
int64_t srcIdx = num_ind + j;
|
| 46 |
+
TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(), " at index ", j,
|
| 47 |
+
" does not match the shape of the indexed tensor ", self.sizes(), " at index ", srcIdx);
|
| 48 |
+
}
|
| 49 |
+
num_ind += mask.ndimension();
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) {
|
| 54 |
+
mask = mask.unsqueeze(-1);
|
| 55 |
+
}
|
| 56 |
+
return std::make_tuple(true, mask);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
|
| 60 |
+
checkIndexTensorTypes(orig, /*allow_int*/ true);
|
| 61 |
+
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
|
| 62 |
+
auto indices = expandTensors(self, orig);
|
| 63 |
+
// next broadcast all index tensors together
|
| 64 |
+
try {
|
| 65 |
+
indices = expand_outplace(indices);
|
| 66 |
+
} catch (std::exception& e) {
|
| 67 |
+
TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
|
| 68 |
+
" with shapes ", shapes_as_str(indices));
|
| 69 |
+
}
|
| 70 |
+
// add missing null Tensors so that it matches self.dim()
|
| 71 |
+
while (indices.size() < (size_t)self.dim()) {
|
| 72 |
+
indices.emplace_back();
|
| 73 |
+
}
|
| 74 |
+
// if the non-null indices are not all adjacent, transpose self and indices
|
| 75 |
+
// together so that they're adjacent at the front
|
| 76 |
+
if (!hasContiguousSubspace(indices)) {
|
| 77 |
+
std::tie(self, indices) = transposeToFront(self, indices);
|
| 78 |
+
}
|
| 79 |
+
// Ensure indices are on the same device as self
|
| 80 |
+
for (auto & indice : indices) {
|
| 81 |
+
if (indice.defined() && indice.device() != self.device()) {
|
| 82 |
+
indice = indice.to(self.device());
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
for (auto & indice : indices) {
|
| 86 |
+
if (indice.defined() && indice.dtype() == at::kInt) {
|
| 87 |
+
indice = indice.to(at::kLong);
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
return AdvancedIndex(self, indices);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorDimApply.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
//input tensors are non-zero dim and non-empty
|
| 7 |
+
template<typename T1, typename T2, typename Function>
|
| 8 |
+
|
| 9 |
+
void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
|
| 10 |
+
int ndims = self.dim();
|
| 11 |
+
int tensor_dim_apply_has_finished = 0;
|
| 12 |
+
std::vector<int64_t> counter(ndims, 0);
|
| 13 |
+
const T1* self_data = self.const_data_ptr<T1>();
|
| 14 |
+
T1* values_data = values.data_ptr<T1>();
|
| 15 |
+
T2* indices_data = indices.data_ptr<T2>();
|
| 16 |
+
int64_t self_stride = self.stride(dim);
|
| 17 |
+
int64_t values_stride = values.stride(dim);
|
| 18 |
+
int64_t indices_stride = indices.stride(dim);
|
| 19 |
+
int self_dim_size = self.size(dim);
|
| 20 |
+
|
| 21 |
+
while (!tensor_dim_apply_has_finished) {
|
| 22 |
+
func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
|
| 23 |
+
if (ndims == 1) {
|
| 24 |
+
break;
|
| 25 |
+
}
|
| 26 |
+
for (const auto dim_i : c10::irange(ndims)) {
|
| 27 |
+
if (dim_i == dim) {
|
| 28 |
+
if (dim_i == (ndims - 1)) {
|
| 29 |
+
tensor_dim_apply_has_finished = 1;
|
| 30 |
+
break;
|
| 31 |
+
}
|
| 32 |
+
continue;
|
| 33 |
+
}
|
| 34 |
+
counter[dim_i]++;
|
| 35 |
+
self_data += self.stride(dim_i);
|
| 36 |
+
values_data += values.stride(dim_i);
|
| 37 |
+
indices_data += indices.stride(dim_i);
|
| 38 |
+
|
| 39 |
+
if (counter[dim_i] == self.size(dim_i)) {
|
| 40 |
+
if (dim_i == ndims-1) {
|
| 41 |
+
tensor_dim_apply_has_finished = 1;
|
| 42 |
+
break;
|
| 43 |
+
} else {
|
| 44 |
+
self_data -= counter[dim_i]*self.stride(dim_i);
|
| 45 |
+
values_data -= counter[dim_i]*values.stride(dim_i);
|
| 46 |
+
indices_data -= counter[dim_i]*indices.stride(dim_i);
|
| 47 |
+
counter[dim_i] = 0;
|
| 48 |
+
}
|
| 49 |
+
} else {
|
| 50 |
+
break;
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorFactories.h
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/EmptyTensor.h>
|
| 5 |
+
#include <ATen/TensorIterator.h>
|
| 6 |
+
#include <ATen/Dispatch.h>
|
| 7 |
+
#include <ATen/Dispatch_v2.h>
|
| 8 |
+
#include <ATen/native/DispatchStub.h>
|
| 9 |
+
|
| 10 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 11 |
+
#include <ATen/Functions.h>
|
| 12 |
+
#else
|
| 13 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
namespace at::native {
|
| 17 |
+
// Different combinations of row, col, and offset can lead to two cases:
|
| 18 |
+
//
|
| 19 |
+
// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
|
| 20 |
+
// Example A: offset > 0
|
| 21 |
+
// 1 1 0 0 0
|
| 22 |
+
// 1 1 1 0 0
|
| 23 |
+
// 1 1 1 1 0
|
| 24 |
+
// Example B: offset <= 0
|
| 25 |
+
// 0 0 0
|
| 26 |
+
// 1 0 0
|
| 27 |
+
// 1 1 0
|
| 28 |
+
// In this case, we calculate the number of elements in the first row and
|
| 29 |
+
// last row of the tril respectively, and then compute the tril size.
|
| 30 |
+
//
|
| 31 |
+
// Case 2 - Trapezoid + Rectangle: row + offset > col
|
| 32 |
+
// Example:
|
| 33 |
+
// 1 1 0
|
| 34 |
+
// 1 1 1
|
| 35 |
+
// 1 1 1
|
| 36 |
+
// In this case, we first calculate the size of top trapezoid, and then
|
| 37 |
+
// calculate the size of the bottom rectangle.
|
| 38 |
+
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
|
| 39 |
+
// If either dimension is 0 then the there is no tril
|
| 40 |
+
if (row == 0 || col == 0) {
|
| 41 |
+
return 0;
|
| 42 |
+
}
|
| 43 |
+
// number of elements in the first row of the tril
|
| 44 |
+
auto m_first_row = offset > 0 ?
|
| 45 |
+
std::min<int64_t>(col, 1 + offset) : // upper bounded by col
|
| 46 |
+
row + offset > 0; // either 0 or 1
|
| 47 |
+
// number of elements in the last row of the tril, bounded by [0, col]
|
| 48 |
+
auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
|
| 49 |
+
// number of rows, bounded by [0, row]
|
| 50 |
+
auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
|
| 51 |
+
auto n_row_trapezoid = (m_last_row - m_first_row + 1);
|
| 52 |
+
|
| 53 |
+
// calculate # of elements in the top trapezoid
|
| 54 |
+
auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;
|
| 55 |
+
|
| 56 |
+
// calculate # of elements in the bottom rectangle if there is any
|
| 57 |
+
auto diff_row = n_row_all - n_row_trapezoid;
|
| 58 |
+
if (diff_row > 0) {
|
| 59 |
+
tril_size += diff_row * col;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
return tril_size;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
inline void check_args(
|
| 66 |
+
int64_t row, int64_t col, std::optional<Layout> layout_opt) {
|
| 67 |
+
TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
|
| 68 |
+
TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
|
| 69 |
+
if (layout_opt.has_value()) {
|
| 70 |
+
TORCH_CHECK(
|
| 71 |
+
*layout_opt == at::kStrided,
|
| 72 |
+
"only support layout=torch.strided, got",
|
| 73 |
+
*layout_opt)
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
using at::check_size_nonnegative;
|
| 78 |
+
|
| 79 |
+
// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
|
| 80 |
+
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
|
| 81 |
+
// match defined() to behavior of checks below
|
| 82 |
+
TORCH_CHECK(at::scalar_tensor(n>0?n-1:n, tensor.options()).defined(),
|
| 83 |
+
"n is too large for result tensor type: '", tensor.toString(), "'");
|
| 84 |
+
|
| 85 |
+
// Ensure sufficient precision for floating point representation.
|
| 86 |
+
switch (tensor.scalar_type()) {
|
| 87 |
+
case at::ScalarType::Half:
|
| 88 |
+
TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
|
| 89 |
+
break;
|
| 90 |
+
case at::ScalarType::Float:
|
| 91 |
+
TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
|
| 92 |
+
break;
|
| 93 |
+
case at::ScalarType::Double: // Unlikely to happen, but doesn't hurt to check
|
| 94 |
+
TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
|
| 95 |
+
break;
|
| 96 |
+
default:
|
| 97 |
+
break;
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// Called by `empty*` functions when deterministic algorithms are enabled to
|
| 102 |
+
// fill the tensor with NaN if it is floating point or complex type, or fill
|
| 103 |
+
// with max value if it is integer type
|
| 104 |
+
inline Tensor& fill_empty_deterministic_(Tensor& tensor) {
|
| 105 |
+
if (tensor.is_floating_point() || tensor.is_complex()) {
|
| 106 |
+
AT_DISPATCH_V2(
|
| 107 |
+
tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
|
| 108 |
+
tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
|
| 109 |
+
}), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf);
|
| 110 |
+
} else {
|
| 111 |
+
AT_DISPATCH_V2(
|
| 112 |
+
tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() {
|
| 113 |
+
tensor.fill_(std::numeric_limits<scalar_t>::max());
|
| 114 |
+
}), kBool, AT_EXPAND(AT_INTEGRAL_TYPES_V2));
|
| 115 |
+
}
|
| 116 |
+
return tensor;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
// The ZeroTensor allocator ignores whatever allocation is requested and always
|
| 120 |
+
// gives you nullptr
|
| 121 |
+
struct ZeroTensorAllocator final : public at::Allocator {
|
| 122 |
+
ZeroTensorAllocator(at::Device device) : device_(device) {};
|
| 123 |
+
~ZeroTensorAllocator() override = default;
|
| 124 |
+
static void deleter(void* const pointer) {
|
| 125 |
+
TORCH_INTERNAL_ASSERT(!pointer);
|
| 126 |
+
}
|
| 127 |
+
DataPtr allocate(const size_t /*nbytes*/) override {
|
| 128 |
+
return {nullptr, nullptr, &deleter, device_};
|
| 129 |
+
}
|
| 130 |
+
DeleterFnPtr raw_deleter() const override {
|
| 131 |
+
return deleter;
|
| 132 |
+
}
|
| 133 |
+
void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const final {}
|
| 134 |
+
at::Device device_;
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
using binary_fn = void (*)(TensorIterator&);
|
| 138 |
+
|
| 139 |
+
DECLARE_DISPATCH(binary_fn, complex_stub);
|
| 140 |
+
DECLARE_DISPATCH(binary_fn, polar_stub);
|
| 141 |
+
|
| 142 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIteratorDynamicCasting.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <complex>
|
| 4 |
+
#include <type_traits>
|
| 5 |
+
#include <c10/core/ScalarType.h>
|
| 6 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 7 |
+
#include <ATen/native/TensorIterator.h>
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.
|
| 11 |
+
|
| 12 |
+
// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
|
| 13 |
+
// to the function that is being called.
|
| 14 |
+
// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
|
| 15 |
+
// On CPU, there is currently an internal assert that a dynamic_cast is not needed.
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
|
| 19 |
+
// `needs_dynamic_casting` compares the types expected by iterator
|
| 20 |
+
// (i.e. dtypes of the operands) with the actual type of the arguments
|
| 21 |
+
// (and returns) of func_t
|
| 22 |
+
template<typename func_t, int nargs=function_traits<func_t>::arity>
|
| 23 |
+
struct needs_dynamic_casting {
|
| 24 |
+
static bool check(TensorIteratorBase& iter) {
|
| 25 |
+
using traits = function_traits<func_t>;
|
| 26 |
+
using cpp_type = typename traits::template arg<nargs - 1>::type;
|
| 27 |
+
using cpp_map = c10::CppTypeToScalarType<cpp_type>;
|
| 28 |
+
|
| 29 |
+
if (iter.input_dtype(nargs-1) != cpp_map::value) {
|
| 30 |
+
return true;
|
| 31 |
+
}
|
| 32 |
+
return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
|
| 33 |
+
}
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
template<typename func_t>
|
| 37 |
+
struct needs_dynamic_casting<func_t, 0> {
|
| 38 |
+
static bool check(TensorIteratorBase& iter) {
|
| 39 |
+
using traits = function_traits<func_t>;
|
| 40 |
+
using cpp_type = typename traits::result_type;
|
| 41 |
+
|
| 42 |
+
// we could assert output numbers are correct here, but checks
|
| 43 |
+
// (including arity) are currently pushed outside of this struct.
|
| 44 |
+
if constexpr (std::is_void_v<cpp_type>) {
|
| 45 |
+
return false;
|
| 46 |
+
} else {
|
| 47 |
+
return iter.dtype(0) != c10::CppTypeToScalarType<cpp_type>::value;
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
} //namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorShape.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
|
| 9 |
+
|
| 10 |
+
inline bool cat_should_skip_tensor(const Tensor& t) {
|
| 11 |
+
return t.sym_numel() == 0 && t.dim() == 1;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
// Check to see if the shape of tensors is compatible
|
| 15 |
+
// for being concatenated along a given dimension.
|
| 16 |
+
inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
|
| 17 |
+
int64_t first_dims = first.dim();
|
| 18 |
+
int64_t second_dims = second.dim();
|
| 19 |
+
TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
|
| 20 |
+
first_dims, " and ", second_dims);
|
| 21 |
+
for (const auto dim : c10::irange(first_dims)) {
|
| 22 |
+
if (dim == dimension) {
|
| 23 |
+
continue;
|
| 24 |
+
}
|
| 25 |
+
int64_t first_dim_size = first.sizes()[dim];
|
| 26 |
+
int64_t second_dim_size = second.sizes()[dim];
|
| 27 |
+
TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
|
| 28 |
+
dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
|
| 33 |
+
int64_t i = 0;
|
| 34 |
+
for(const Tensor& t : tensors) {
|
| 35 |
+
TORCH_CHECK(t.dim() > 0,
|
| 36 |
+
"zero-dimensional tensor (at position ", i, ") cannot be concatenated");
|
| 37 |
+
i++;
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
|
| 42 |
+
TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
|
| 43 |
+
TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
|
| 44 |
+
int64_t dim_size = self.size(dim);
|
| 45 |
+
TORCH_CHECK(split_size > 0 || dim_size == 0,
|
| 46 |
+
"split_size can only be 0 if dimension size is 0, "
|
| 47 |
+
"but got dimension size of ", dim_size);
|
| 48 |
+
// if split_size is 0 and dimension size is 0, there is 1 split.
|
| 49 |
+
int64_t num_splits = 1;
|
| 50 |
+
if (split_size != 0) {
|
| 51 |
+
// ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
|
| 52 |
+
// (returns a single split). We might want to error here, but keep it for BC.
|
| 53 |
+
num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
|
| 54 |
+
}
|
| 55 |
+
return num_splits;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
inline bool have_same_ndims(TensorList tensors) {
|
| 59 |
+
auto ndim = tensors[0].dim();
|
| 60 |
+
for (const auto tensor_idx : c10::irange(tensors.size())) {
|
| 61 |
+
if(tensors[tensor_idx].dim() != ndim) {
|
| 62 |
+
return false;
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
return true;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
|
| 69 |
+
auto tensor_zero_size = tensors[0].sizes();
|
| 70 |
+
std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
|
| 71 |
+
for (const auto i : c10::irange(tensors.size())) {
|
| 72 |
+
at::Tensor tensor = tensors[i];
|
| 73 |
+
for(const auto j : c10::irange(dim)) {
|
| 74 |
+
TORCH_CHECK(
|
| 75 |
+
tensor.size(j) == leading_dim_sizes[j],
|
| 76 |
+
"_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
|
| 77 |
+
);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
|
| 83 |
+
TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
|
| 84 |
+
TORCH_CHECK(!tensors.empty(),
|
| 85 |
+
"_chunk_cat expects a non-empty input tensor list");
|
| 86 |
+
auto expected_dtype = tensors[0].dtype();
|
| 87 |
+
auto expected_device = tensors[0].device();
|
| 88 |
+
for(const auto i : c10::irange(tensors.size())) {
|
| 89 |
+
TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
|
| 90 |
+
TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
|
| 91 |
+
TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
|
| 92 |
+
}
|
| 93 |
+
if (have_same_ndims(tensors)) {
|
| 94 |
+
dim = maybe_wrap_dim(dim, tensors[0].dim());
|
| 95 |
+
} else {
|
| 96 |
+
TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
|
| 97 |
+
for(const auto i : c10::irange(tensors.size())) {
|
| 98 |
+
TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
leading_dimension_matches(tensors, dim);
|
| 102 |
+
return dim;
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
#include <ATen/native/LinearAlgebraUtils.h>
|
| 3 |
+
|
| 4 |
+
namespace at::native {
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
* Given batches of matrices with arbitrary batch dim,
|
| 8 |
+
* computes the number of batches for Triu and Tril. This ignores stride 0 dimension
|
| 9 |
+
*/
|
| 10 |
+
static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
|
| 11 |
+
int64_t result = 1;
|
| 12 |
+
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
|
| 13 |
+
if (batched_matrices.stride(i) != 0) {
|
| 14 |
+
result *= batched_matrices.size(i);
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
return result;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
/* Checks a necessary property for the triu and tril implementations, hence the name.
|
| 21 |
+
* Here batch contiguity is checked for tensors with greater than 4 dimensions.
|
| 22 |
+
* Contiguous tensors and tensors with less than 3 dimensions pass this check
|
| 23 |
+
*/
|
| 24 |
+
static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
|
| 25 |
+
// Complete contiguity is the most desired property, which is why
|
| 26 |
+
// we return true if the tensor is contiguous
|
| 27 |
+
if (tensor.is_contiguous()) {
|
| 28 |
+
auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
|
| 29 |
+
if (tensor.strides() == default_strides_for_size) {
|
| 30 |
+
return std::make_tuple(true, tensor);
|
| 31 |
+
} else {
|
| 32 |
+
return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
int64_t dims = tensor.dim();
|
| 37 |
+
|
| 38 |
+
// Tensors with dimension less than 4 are handled by default
|
| 39 |
+
if (allow_zero_stride && dims <= 3) {
|
| 40 |
+
return std::make_tuple(true, tensor);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
|
| 44 |
+
for (int64_t i = dims - 3; i >= 0; i--) {
|
| 45 |
+
// Skip trivial dimension;
|
| 46 |
+
if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
|
| 47 |
+
continue;
|
| 48 |
+
}
|
| 49 |
+
if (expected_stride != tensor.stride(i)) {
|
| 50 |
+
return std::make_tuple(false, tensor.contiguous());
|
| 51 |
+
}
|
| 52 |
+
expected_stride *= tensor.size(i);
|
| 53 |
+
}
|
| 54 |
+
return std::make_tuple(true, tensor);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/UnaryOps.h
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/Generator.h>
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <stdexcept>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
class Tensor;
|
| 10 |
+
class TensorBase;
|
| 11 |
+
struct TensorIteratorBase;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
|
| 16 |
+
using unary_fn = void(*)(TensorIteratorBase&);
|
| 17 |
+
using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
|
| 18 |
+
|
| 19 |
+
inline namespace CPU_CAPABILITY {
|
| 20 |
+
void conj_kernel(TensorIteratorBase &iter);
|
| 21 |
+
void neg_kernel(TensorIteratorBase &iter);
|
| 22 |
+
void reciprocal_kernel(TensorIteratorBase &iter);
|
| 23 |
+
void rsqrt_kernel(TensorIteratorBase& iter);
|
| 24 |
+
void sqrt_kernel(TensorIteratorBase& iter);
|
| 25 |
+
} // namespace CPU_CAPABILITY
|
| 26 |
+
|
| 27 |
+
DECLARE_DISPATCH(unary_fn, abs_stub);
|
| 28 |
+
DECLARE_DISPATCH(unary_fn, angle_stub);
|
| 29 |
+
DECLARE_DISPATCH(unary_fn, conj_physical_stub);
|
| 30 |
+
DECLARE_DISPATCH(unary_fn, acos_stub);
|
| 31 |
+
DECLARE_DISPATCH(unary_fn, acosh_stub);
|
| 32 |
+
DECLARE_DISPATCH(unary_fn, asinh_stub);
|
| 33 |
+
DECLARE_DISPATCH(unary_fn, atanh_stub);
|
| 34 |
+
DECLARE_DISPATCH(unary_fn, asin_stub);
|
| 35 |
+
DECLARE_DISPATCH(unary_fn, atan_stub);
|
| 36 |
+
DECLARE_DISPATCH(unary_fn, bitwise_not_stub);
|
| 37 |
+
DECLARE_DISPATCH(unary_fn, logical_not_stub);
|
| 38 |
+
DECLARE_DISPATCH(unary_fn, ceil_stub);
|
| 39 |
+
DECLARE_DISPATCH(unary_fn, cos_stub);
|
| 40 |
+
DECLARE_DISPATCH(unary_fn, cosh_stub);
|
| 41 |
+
DECLARE_DISPATCH(unary_fn, digamma_stub);
|
| 42 |
+
DECLARE_DISPATCH(unary_fn, special_entr_stub);
|
| 43 |
+
DECLARE_DISPATCH(unary_fn, special_erfcx_stub);
|
| 44 |
+
DECLARE_DISPATCH(unary_fn, erf_stub);
|
| 45 |
+
DECLARE_DISPATCH(unary_fn, erfc_stub);
|
| 46 |
+
DECLARE_DISPATCH(unary_fn, erfinv_stub);
|
| 47 |
+
DECLARE_DISPATCH(unary_fn, exp_stub);
|
| 48 |
+
DECLARE_DISPATCH(unary_fn, exp2_stub);
|
| 49 |
+
DECLARE_DISPATCH(unary_fn, expm1_stub);
|
| 50 |
+
DECLARE_DISPATCH(unary_fn, floor_stub);
|
| 51 |
+
DECLARE_DISPATCH(unary_fn, frac_stub);
|
| 52 |
+
DECLARE_DISPATCH(unary_fn, frexp_stub);
|
| 53 |
+
DECLARE_DISPATCH(unary_fn, i0_stub);
|
| 54 |
+
DECLARE_DISPATCH(unary_fn, special_i0e_stub);
|
| 55 |
+
DECLARE_DISPATCH(unary_fn, special_i1_stub);
|
| 56 |
+
DECLARE_DISPATCH(unary_fn, special_i1e_stub);
|
| 57 |
+
DECLARE_DISPATCH(unary_fn, log_stub);
|
| 58 |
+
DECLARE_DISPATCH(unary_fn, log10_stub);
|
| 59 |
+
DECLARE_DISPATCH(unary_fn, log1p_stub);
|
| 60 |
+
DECLARE_DISPATCH(unary_fn, log2_stub);
|
| 61 |
+
DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
|
| 62 |
+
DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
|
| 63 |
+
DECLARE_DISPATCH(unary_fn, neg_stub);
|
| 64 |
+
|
| 65 |
+
DECLARE_DISPATCH(unary_fn, reciprocal_stub);
|
| 66 |
+
DECLARE_DISPATCH(unary_fn, round_stub);
|
| 67 |
+
DECLARE_DISPATCH(unary_fn, rsqrt_stub);
|
| 68 |
+
DECLARE_DISPATCH(unary_fn, sigmoid_stub);
|
| 69 |
+
DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub);
|
| 70 |
+
DECLARE_DISPATCH(unary_fn, sign_stub);
|
| 71 |
+
DECLARE_DISPATCH(unary_fn, signbit_stub);
|
| 72 |
+
DECLARE_DISPATCH(unary_fn, sgn_stub);
|
| 73 |
+
DECLARE_DISPATCH(unary_fn, sin_stub);
|
| 74 |
+
DECLARE_DISPATCH(unary_fn, sinc_stub);
|
| 75 |
+
DECLARE_DISPATCH(unary_fn, sinh_stub);
|
| 76 |
+
DECLARE_DISPATCH(unary_fn, sqrt_stub);
|
| 77 |
+
DECLARE_DISPATCH(unary_fn, tan_stub);
|
| 78 |
+
DECLARE_DISPATCH(unary_fn, tanh_stub);
|
| 79 |
+
DECLARE_DISPATCH(unary_fn, trigamma_stub);
|
| 80 |
+
DECLARE_DISPATCH(unary_fn, trunc_stub);
|
| 81 |
+
DECLARE_DISPATCH(unary_fn, lgamma_stub);
|
| 82 |
+
DECLARE_DISPATCH(unary_fn, special_airy_ai_stub);
|
| 83 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
|
| 84 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
|
| 85 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
|
| 86 |
+
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
|
| 87 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
|
| 88 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
|
| 89 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
|
| 90 |
+
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub);
|
| 91 |
+
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub);
|
| 92 |
+
DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub);
|
| 93 |
+
DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub);
|
| 94 |
+
|
| 95 |
+
// NB: these are actually defined in Distribution
|
| 96 |
+
DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub);
|
| 97 |
+
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub);
|
| 98 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub);
|
| 99 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub);
|
| 100 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub);
|
| 101 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub);
|
| 102 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub);
|
| 103 |
+
DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub);
|
| 104 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub);
|
| 105 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub);
|
| 106 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub);
|
| 107 |
+
|
| 108 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub);
|
| 109 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub);
|
| 110 |
+
DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const Scalar& a, const Scalar& b), clamp_stub);
|
| 111 |
+
DECLARE_DISPATCH(
|
| 112 |
+
void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>),
|
| 113 |
+
multinomial_with_replacement_stub);
|
| 114 |
+
DECLARE_DISPATCH(
|
| 115 |
+
void (*)(
|
| 116 |
+
TensorIteratorBase&,
|
| 117 |
+
std::optional<double>,
|
| 118 |
+
std::optional<double>,
|
| 119 |
+
std::optional<double>),
|
| 120 |
+
nan_to_num_stub);
|
| 121 |
+
DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub);
|
| 122 |
+
|
| 123 |
+
// Missing unary functions
|
| 124 |
+
// digamma
|
| 125 |
+
// lgamma
|
| 126 |
+
// erfinv
|
| 127 |
+
// clone
|
| 128 |
+
// contiguous
|
| 129 |
+
// zero
|
| 130 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/Unfold2d.h
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
using unfold2d_copy_fn = void (*)(
|
| 10 |
+
ScalarType dtype,
|
| 11 |
+
void *finput,
|
| 12 |
+
const void *input,
|
| 13 |
+
int64_t kH,
|
| 14 |
+
int64_t kW,
|
| 15 |
+
int64_t dH,
|
| 16 |
+
int64_t dW,
|
| 17 |
+
int64_t padH,
|
| 18 |
+
int64_t padW,
|
| 19 |
+
int64_t n_input_plane,
|
| 20 |
+
int64_t input_height,
|
| 21 |
+
int64_t input_width,
|
| 22 |
+
int64_t output_height,
|
| 23 |
+
int64_t output_width,
|
| 24 |
+
bool is_channels_last
|
| 25 |
+
);
|
| 26 |
+
|
| 27 |
+
using unfold2d_acc_fn = void (*)(
|
| 28 |
+
ScalarType dtype,
|
| 29 |
+
void *finput,
|
| 30 |
+
void *input,
|
| 31 |
+
int64_t kH,
|
| 32 |
+
int64_t kW,
|
| 33 |
+
int64_t dH,
|
| 34 |
+
int64_t dW,
|
| 35 |
+
int64_t padH,
|
| 36 |
+
int64_t padW,
|
| 37 |
+
int64_t n_input_plane,
|
| 38 |
+
int64_t input_height,
|
| 39 |
+
int64_t input_width,
|
| 40 |
+
int64_t output_height,
|
| 41 |
+
int64_t output_width,
|
| 42 |
+
bool is_channels_last
|
| 43 |
+
);
|
| 44 |
+
|
| 45 |
+
DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub);
|
| 46 |
+
DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub);
|
| 47 |
+
|
| 48 |
+
} // namespace at::native
|