|
|
#pragma once
|
|
|
|
|
|
#include <c10/core/ScalarType.h>
|
|
|
#include <c10/util/irange.h>
|
|
|
#include <c10/util/Exception.h>
|
|
|
#include <c10/util/strides.h>
|
|
|
#include <ATen/core/Tensor.h>
|
|
|
#include <ATen/ExpandUtils.h>
|
|
|
#include <ATen/TensorUtils.h>
|
|
|
#include <ATen/native/TensorIterator.h>
|
|
|
#include <ATen/native/TransposeType.h>
|
|
|
#include <limits>
|
|
|
#include <type_traits>
|
|
|
#include <sstream>
|
|
|
#include <cstring>
|
|
|
#include <cctype>
|
|
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
|
#include <ATen/Functions.h>
|
|
|
#else
|
|
|
#include <ATen/ops/arange.h>
|
|
|
#include <ATen/ops/empty.h>
|
|
|
#include <ATen/ops/empty_like.h>
|
|
|
#include <ATen/ops/empty_strided.h>
|
|
|
#include <ATen/ops/zeros.h>
|
|
|
#endif
|
|
|
|
|
|
namespace at::native {
|
|
|
|
|
|
// Returns `tensor` with any lazy conjugation materialized: a lazily
// conjugated tensor yields an owned, conj-resolved copy, while an
// ordinary tensor is borrowed without any copy.
inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
  return tensor.is_conj()
      ? c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj())
      : c10::MaybeOwned<Tensor>::borrowed(tensor);
}
|
|
|
|
|
|
// Computes contiguous strides for `sizes`. By default the layout is
// C-contiguous (row-major); when `f_contig` is true the last two
// dimensions are instead laid out column-major (Fortran order), which is
// the layout LAPACK-style routines expect for each matrix of the batch.
inline DimVector batched_matrix_contiguous_strides(
    const IntArrayRef sizes,
    const bool f_contig = false) {
  auto strides = c10::contiguous_strides(sizes);
  const auto ndim = strides.size();

  if (f_contig && ndim >= 2) {
    // Column-major matrix layout: unit stride down a column, and a
    // leading dimension (column-to-column stride) of at least 1.
    strides[ndim - 2] = 1;
    strides[ndim - 1] = std::max(sizes[ndim - 2], static_cast<int64_t>(1));
  }
  return strides;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Clones `src` so that every matrix in the batch ends up in column-major
// storage: clone the matrix-transposed view contiguously, then transpose
// back in place so the result has src's logical shape.
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  Tensor column_major = src.mT().clone(at::MemoryFormat::Contiguous);
  column_major.transpose_(-2, -1);
  return column_major;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// If `cond` holds, borrows `borrow`; otherwise returns an owned copy of
// `clone` — a plain contiguous clone when `contig` is true, a batched
// column-major clone otherwise.
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
  if (cond) {
    return c10::MaybeOwned<Tensor>::borrowed(borrow);
  }
  Tensor copied = contig ? clone.clone(MemoryFormat::Contiguous)
                         : cloneBatchedColumnMajor(clone);
  return c10::MaybeOwned<Tensor>::owned(std::move(copied));
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Allocates a batch of column-major matrices and copies `src` into it.
// Each output matrix has `nrows` rows (defaulting to src's row count) and
// src's column count; the batch shape is either `desired_batch_sizes` or
// src's own batch dims. src is copied into the leading src.size(-2) rows;
// any extra rows are left uninitialized.
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
    at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
  if (nrows == -1) {
    nrows = src.size(-2);
  }
  // Batch dimensions come from the caller if given, else from src.
  auto result_sizes = desired_batch_sizes.has_value()
      ? desired_batch_sizes.value().vec()
      : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
  result_sizes.push_back(nrows);
  result_sizes.push_back(src.size(-1));

  const auto result_strides =
      batched_matrix_contiguous_strides(result_sizes, /*f_contig=*/true);
  auto result = at::empty_strided(result_sizes, result_strides, src.options());
  result.narrow(-2, 0, src.size(-2)).copy_(src);
  return result;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inline int64_t batchCount(const Tensor& batched_matrices) {
|
|
|
int64_t result = 1;
|
|
|
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
|
|
|
result *= batched_matrices.size(i);
|
|
|
}
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
|
|
|
inline int64_t matrixStride(const Tensor& batched_matrices) {
|
|
|
return batched_matrices.size(-1) * batched_matrices.size(-2);
|
|
|
}
|
|
|
|
|
|
|
|
|
// Validates that `A` is at least 2-D (a matrix or a batch of matrices);
// the error message names the calling function and the argument.
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
  const bool at_least_2d = A.dim() >= 2;
  TORCH_CHECK(at_least_2d, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}
|
|
|
// Validates that `self` is a (batch of) square matrix: at least 2-D with
// its last two (symbolic) dimensions equal.
inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
  checkIsMatrix(self, f_name, arg_name);
  const auto rows = self.sym_size(-2);
  const auto cols = self.sym_size(-1);
  TORCH_CHECK(cols == rows,
              f_name,
              ": ", arg_name, " must be batches of square matrices, "
              "but they are ", rows, " by ", cols, " matrices");
}
|
|
|
|
|
|
inline void checkInputsSolver(const Tensor& A,
|
|
|
const Tensor& B,
|
|
|
const bool left,
|
|
|
const char* const f_name) {
|
|
|
squareCheckInputs(A, f_name, "A");
|
|
|
checkIsMatrix(B, f_name, "B");
|
|
|
TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
|
|
|
f_name, ": Incompatible shapes of A and B for the equation ",
|
|
|
left ? "AX = B" : "XA = B",
|
|
|
" (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
|
|
|
}
|
|
|
|
|
|
inline bool is_row_or_column_contiguous(const Tensor& t) {
|
|
|
|
|
|
|
|
|
|
|
|
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
|
|
|
}
|
|
|
|
|
|
// Maps (contiguous, conjugated) flags onto the TransposeType used by the
// BLAS/LAPACK wrappers. A tensor that is both contiguous and conjugated
// has no corresponding transpose type and trips an internal assert.
inline TransposeType to_transpose_type(const bool contig, const bool conj) {
  if (!conj) {
    return contig ? TransposeType::NoTranspose : TransposeType::Transpose;
  }
  if (contig) {
    TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
  }
  return TransposeType::ConjTranspose;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Iterates in lockstep over the matrices of batched tensors `a` and `b`,
// calling f(a_working_ptr, b_working_ptr, a_linear_batch_idx) once per
// matrix of `b`. The batch dims of `a` are broadcast against those of `b`
// via a TensorIterator built over linear batch-index tensors, so a single
// matrix of `a` may be visited multiple times. Because `f` may modify the
// `a` matrix it is handed (e.g. an in-place LAPACK factorization), a
// pristine copy of `a` is kept when broadcasting occurs, and each `a`
// matrix is restored from that copy on every visit after the first.
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
  // Batch shapes: everything but the trailing matrix dims.
  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);

  // Linear batch indices shaped like each tensor's batch dims; the
  // TensorIterator below broadcasts a's indices over b's.
  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);

  // The iterator only pairs up batch indices — dtypes differ from the
  // data tensors and the "output" is never resized, hence the relaxed
  // config flags.
  TensorIterator iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(b_linear_batch_idx)
    .add_input(a_linear_batch_idx)
    .build();

  // Flattened 3-D views: (batch, rows, cols).
  auto m = a.size(-2);
  auto n = a.size(-1);
  auto a_3d = a.view({batchCount(a), m, n});
  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});

  // Only when a's batch shape differs from b's can the same `a` matrix be
  // visited more than once, requiring the restore-from-buffer machinery.
  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
  Tensor a_buffer, a_was_accessed, a_buffer_3d;
  // Default: no broadcasting, so the pre-visit hook is a no-op.
  std::function<void(int64_t)> check_if_copy_needed_for_a
    = [](int64_t ){};
  if (a_broadcasts_over_b) {
    // Pristine copy of `a` (same layout) used to restore matrices that
    // `f` may have clobbered on a previous visit.
    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
      .copy_(a);
    // One "already visited" flag per matrix of `a`.
    a_was_accessed = at::zeros(batchCount(a), at::kBool);
    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
      auto* a_was_accessed_flag = a_was_accessed
        .select(0, a_curr_linear_batch_idx)
        .data_ptr<bool>();
      if (!(*a_was_accessed_flag)) {
        // First visit: the matrix is still pristine; just mark it.
        *a_was_accessed_flag = true;
      }
      else {
        // Revisit: restore the matrix from the untouched buffer.
        a_3d.select(0, a_curr_linear_batch_idx)
          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
      }
    };
  }

  // Serial loop body: data[0]/data[1] walk the (broadcast) b/a linear
  // batch-index tensors; each pair selects the matrices handed to `f`.
  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* b_batch_idx_ptr = data[0];
    auto* a_batch_idx_ptr = data[1];

    for ([[maybe_unused]] const auto elem : c10::irange(nelems)) {
      auto b_curr_linear_batch_idx =
        *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

      // Restore `a`'s matrix if it was already handed to `f` once.
      check_if_copy_needed_for_a(a_curr_linear_batch_idx);

      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);

      b_batch_idx_ptr += strides[0];
      a_batch_idx_ptr += strides[1];
    }
  };
  // Serial (not parallel) iteration: the visited-flag bookkeeping above
  // is not synchronized.
  iter.serial_for_each(loop, {0, batchCount(b)});
}
|
|
|
|
|
|
|
|
|
inline double _get_epsilon(const ScalarType& sc_type) {
|
|
|
switch (sc_type) {
|
|
|
case at::ScalarType::Float:
|
|
|
return static_cast<double>(std::numeric_limits<float>::epsilon());
|
|
|
case at::ScalarType::Double:
|
|
|
return std::numeric_limits<double>::epsilon();
|
|
|
default:
|
|
|
TORCH_CHECK(false, "This function doesn't handle types other than float and double");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Validates the (b, A) pair of a linear solve. Checks run in a fixed
// order — device match, dtype match, squareness of A, then shape
// conformability — so the first failing condition determines the error.
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
  const auto b_device = self.device();
  const auto A_device = A.device();
  TORCH_CHECK(b_device == A_device,
              "Expected b and A to be on the same device, but found b on ",
              b_device, " and A on ", A_device, " instead.");

  const auto b_dtype = self.scalar_type();
  const auto A_dtype = A.scalar_type();
  TORCH_CHECK(b_dtype == A_dtype,
              "Expected b and A to have the same dtype, but found b of type ",
              b_dtype, " and A of type ", A_dtype, " instead.");

  const auto A_rows = A.size(-2);
  const auto A_cols = A.size(-1);
  TORCH_CHECK(A_cols == A_rows,
              "A must be batches of square matrices, "
              "but they are ", A_rows, " by ", A_cols, " matrices");

  TORCH_CHECK(A_cols == self.size(-2),
              "Incompatible matrix sizes for ", name, ": each A "
              "matrix is ", A_cols, " by ", A_cols,
              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
|
|
|
|
|
|
inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
|
|
|
auto dtype = t.scalar_type();
|
|
|
TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
|
|
|
f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
|
|
|
if (!allow_low_precision_dtypes) {
|
|
|
TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
|
|
|
f_name, ": Low precision dtypes not supported. Got ", dtype);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Verifies that every tensor in `tensors` has exactly `dim` dimensions.
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(tensor.dim() == dim, "Tensor dimension is ", tensor.dim(), ", expected ", dim, " instead.");
  }
}
|
|
|
|
|
|
inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
|
|
|
|
|
|
IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
|
|
|
IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
|
|
|
std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
|
|
|
|
|
|
std::vector<int64_t> arg1_expand_size({expand_batch_portion});
|
|
|
arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
|
|
|
|
|
|
std::vector<int64_t> arg2_expand_size({expand_batch_portion});
|
|
|
arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
|
|
|
return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
|
|
|
}
|
|
|
|
|
|
inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
|
|
|
|
|
|
if (name != nullptr) {
|
|
|
linearSolveCheckInputs(arg1, arg2, name);
|
|
|
}
|
|
|
|
|
|
auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
|
|
|
|
|
|
auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
|
|
|
auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
|
|
|
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
|
|
|
}
|
|
|
|
|
|
inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
|
|
|
IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
|
|
|
IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
|
|
|
auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
|
|
|
return broadcasted_batch_sizes;
|
|
|
}
|
|
|
|
|
|
|
|
|
// Returns a permuted view of `self` with the dimensions listed in `axes`
// moved to the end (in the order given), while all remaining dimensions
// keep their relative order. Duplicate or out-of-range axes make the
// permutation the wrong length and raise an error.
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
  const std::vector<int64_t> moved = axes.vec();
  const int64_t ndim = self.ndimension();
  std::vector<int64_t> perm;
  perm.reserve(ndim);

  // First the dimensions that are not being moved, in original order.
  for (int64_t dim = 0; dim < ndim; ++dim) {
    const bool is_moved =
        std::find(moved.begin(), moved.end(), dim) != moved.end();
    if (!is_moved) {
      perm.push_back(dim);
    }
  }
  // Then the moved dimensions, in the order the caller gave them.
  perm.insert(perm.end(), moved.begin(), moved.end());

  TORCH_CHECK(static_cast<int64_t>(perm.size()) == ndim,
    "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);

  return self.permute(perm);
}
|
|
|
|
|
|
|
|
|
// Parses the `mode` string of a QR decomposition into the pair
// (compute_q, reduced). Accepted modes: "reduced", "complete", "r";
// anything else raises.
inline std::tuple<bool, bool> _parse_qr_mode(std::string_view mode) {
  bool compute_q = false;
  bool reduced = true;
  if (mode == "reduced") {
    compute_q = true;
  } else if (mode == "complete") {
    compute_q = true;
    reduced = false;
  } else if (mode == "r") {
    // Q is not computed; `reduced` stays true for the R geometry.
  } else {
    TORCH_CHECK(false, "qr received unrecognized mode '", mode,
                "' but expected one of 'reduced' (default), 'r', or 'complete'");
  }
  return std::make_tuple(compute_q, reduced);
}
|
|
|
|
|
|
|
|
|
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
|
|
|
const Tensor& input,
|
|
|
bool reduced) {
|
|
|
int64_t m = input.size(-2), n = input.size(-1);
|
|
|
int64_t n_columns_q;
|
|
|
|
|
|
|
|
|
DimVector q_sizes(input.sizes());
|
|
|
if (!reduced && m > n) {
|
|
|
q_sizes[input.dim() - 1] = m;
|
|
|
n_columns_q = m;
|
|
|
} else {
|
|
|
q_sizes[input.dim() - 1] = n;
|
|
|
n_columns_q = std::min(m, n);
|
|
|
}
|
|
|
auto q_strides = batched_matrix_contiguous_strides(q_sizes, true);
|
|
|
return std::make_tuple(q_sizes, q_strides, n_columns_q);
|
|
|
}
|
|
|
|
|
|
inline bool svd_uses_cusolver(const Tensor& A) {
|
|
|
|
|
|
return A.is_cuda()
|
|
|
&& at::globalContext().hasCuSOLVER()
|
|
|
&& at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Allocates a tensor with the same sizes and strides as
// `original_tensor` but with the given `options` (e.g. another dtype or
// device), and copies the original's data into it.
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
  auto result = at::empty_strided(
      original_tensor.sizes(), original_tensor.strides(), options);
  result.copy_(original_tensor);
  return result;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Builds a permutation of `ndim` dimension indices that moves `dim0` and
// `dim1` (in that order) to the last two positions, preserving the
// relative order of all other dimensions. Requires distinct, in-range
// dim0/dim1.
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
  TORCH_CHECK(
    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
    "duplicate or invalid dimensions");
  std::vector<int64_t> permutation;
  permutation.reserve(ndim);
  // All dimensions except dim0/dim1, in original order.
  for (int64_t dim = 0; dim < ndim; ++dim) {
    if (dim != dim0 && dim != dim1) {
      permutation.push_back(dim);
    }
  }
  // The two shifted dimensions go last.
  permutation.push_back(dim0);
  permutation.push_back(dim1);
  return permutation;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Inverts a permutation: result[permutation[i]] = i, so applying the
// result after `permutation` restores the original dimension order.
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
  const int64_t ndim = static_cast<int64_t>(permutation.size());
  std::vector<int64_t> inverse(ndim);
  for (int64_t i = 0; i < ndim; ++i) {
    inverse[permutation[i]] = i;
  }
  return inverse;
}
|
|
|
|
|
|
|
|
|
|
|
|
// Real-workspace size for an m-by-n SVD with job mode `jobz`
// ('N' = no singular vectors; other values request them).
// NOTE(review): the formulas look like LAPACK ?gesdd rwork sizing —
// confirm against the LAPACK docs before relying on the exact values.
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
  const auto mn = std::min(m, n);
  const auto mx = std::max(m, n);
  if (jobz == 'N') {
#ifdef __APPLE__
    // Larger workspace required on Apple's LAPACK for the 'N' path.
    return 7 * mn;
#else
    return 5 * mn;
#endif
  }
  // Singular vectors requested.
  const auto base = 5 * mn * mn + 5 * mn;
  if (mx > 10 * mn) {
    return base;
  }
  return std::max(base, 2 * mx * mn + 2 * mn * mn + mn);
}
|
|
|
|
|
|
|
|
|
|
|
|
// Validates an UPLO argument: it must be exactly one character, either
// 'U'/'u' or 'L'/'l'.
//
// Fix: the previous version indexed uplo[0] before verifying the length,
// which is undefined behavior for an empty string_view. The length is
// now checked before the first character is inspected.
inline void checkUplo(const std::string_view uplo) {
  bool valid = uplo.size() == 1;
  if (valid) {
    // Case-insensitive comparison; cast through unsigned char so
    // std::toupper is well-defined for all byte values.
    const char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
    valid = (uplo_uppercase == 'U' || uplo_uppercase == 'L');
  }
  TORCH_CHECK(valid,
              "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}
|
|
|
|
|
|
inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
|
|
TORCH_CHECK(
|
|
|
result.device() == input.device(),
|
|
|
fn_name,
|
|
|
": Expected ", result_name, " and input tensors to be on the same device, but got ",
|
|
|
result_name, " on ", result.device(), " and input on ", input.device());
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
|
|
bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
|
|
|
TORCH_CHECK(
|
|
|
can_cast,
|
|
|
fn_name,
|
|
|
": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
|
|
|
result_name, " with dtype ", result.scalar_type());
|
|
|
}
|
|
|
|
|
|
|
|
|
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
|
|
|
bool can_cast = c10::canCast(result_type, out_type);
|
|
|
TORCH_CHECK(
|
|
|
can_cast,
|
|
|
fn_name,
|
|
|
": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
|
|
|
out_name, " with dtype ", out_type);
|
|
|
}
|
|
|
|
|
|
// Rejects complex-valued tolerance tensors for functions that only
// accept real tolerances.
inline void checkNotComplexTolerance(const Tensor& tol, const std::string_view f_name, const std::string_view tol_name) {
  const bool is_complex = at::isComplexType(tol.scalar_type());
  TORCH_CHECK(!is_complex,
              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
|
|
|
auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1);
|
|
|
bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
|
|
|
return vector_case;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Builds a contiguous CPU int64 tensor of linear indices [0, numel),
// shaped as `original_shape` and then broadcast to `broadcast_shape`.
// Reading it at a broadcast linear position yields the corresponding
// pre-broadcast linear index.
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
  const auto options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
  auto indices = at::arange(numel, options).view(original_shape);
  return indices.broadcast_to(broadcast_shape).contiguous();
}
|
|
|
|
|
|
class BroadcastLinearIndices {
|
|
|
private:
|
|
|
Tensor linear_indices_;
|
|
|
bool is_broadcasting_;
|
|
|
|
|
|
public:
|
|
|
BroadcastLinearIndices(
|
|
|
int64_t numel,
|
|
|
IntArrayRef original_shape,
|
|
|
IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (is_broadcasting_) {
|
|
|
linear_indices_ =
|
|
|
get_linear_indices(numel, original_shape, broadcast_shape);
|
|
|
}
|
|
|
}
|
|
|
int64_t operator()(int64_t broadcast_linear_index) {
|
|
|
return is_broadcasting_
|
|
|
? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
|
|
|
: broadcast_linear_index;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
inline bool is_blas_compatible_column_major_order(const Tensor& input) {
|
|
|
IntArrayRef input_strides = input.strides();
|
|
|
IntArrayRef input_sizes = input.sizes();
|
|
|
auto ndim = input.dim();
|
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
|
|
|
if (ndim > 3) {
|
|
|
return input.transpose(-2, -1).is_contiguous();
|
|
|
}
|
|
|
auto leading_dimension = input_strides[ndim - 1];
|
|
|
auto rows = input_sizes[ndim - 2];
|
|
|
bool batch_stride_compatible = true;
|
|
|
if (ndim == 3) {
|
|
|
auto cols = input_sizes[ndim - 1];
|
|
|
batch_stride_compatible =
|
|
|
input_strides[ndim - 3] >= leading_dimension * cols;
|
|
|
}
|
|
|
return (input_strides[ndim - 2] == 1) &&
|
|
|
(leading_dimension >= std::max<int64_t>(1, rows)) &&
|
|
|
batch_stride_compatible;
|
|
|
}
|
|
|
|
|
|
inline bool is_blas_compatible_row_major_order(const Tensor& input) {
|
|
|
IntArrayRef input_strides = input.strides();
|
|
|
IntArrayRef input_sizes = input.sizes();
|
|
|
auto ndim = input.dim();
|
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
|
|
|
if (ndim > 3) {
|
|
|
return input.is_contiguous();
|
|
|
}
|
|
|
auto leading_dimension = input_strides[ndim - 2];
|
|
|
auto cols = input_sizes[ndim - 1];
|
|
|
bool batch_stride_compatible = true;
|
|
|
if (ndim == 3) {
|
|
|
auto rows = input_sizes[ndim - 2];
|
|
|
batch_stride_compatible =
|
|
|
input_strides[ndim - 3] >= leading_dimension * rows;
|
|
|
}
|
|
|
return (input_strides[ndim - 1] == 1) &&
|
|
|
(leading_dimension >= std::max<int64_t>(1, cols)) &&
|
|
|
batch_stride_compatible;
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|