// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
namespace at::native::detail {
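// Per-input metadata captured once before the copy loop: the tensor's raw
// data pointer and the number of contiguous elements it contributes per
// outer iteration (its size along the stack dimension times the inner
// stride passed by the caller).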
struct InputMeta {
void* data_ptr;
int64_t inner_size;
InputMeta(const Tensor& t, int64_t dim, int64_t inner)
: data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
};
// This kernel is used by two TensorList types:
// 1. stack_serial_kernel uses at::ArrayRef<Tensor>
// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
// ProcessedNodeInputWrapper.
// When making changes, make sure that they are compatible with both types!
template <typename scalar_t, typename TensorListType>
void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
dim >= 0 && dim <= result.dim(),
"dim out of range in stack_serial_kernel_impl");
int64_t outer =
result.numel() / (result.sizes()[dim] * result.strides()[dim]);
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t ninputs = tensors.size();
std::vector<InputMeta> inputs;
inputs.reserve(ninputs);
for (const auto& tensor : tensors) {
inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
}
using Vec = vec::Vectorized<scalar_t>;
scalar_t* result_ptr = result_data;
for (const auto i : c10::irange(outer)) {
for (const auto j : c10::irange(ninputs)) {
int64_t local_inner = inputs[j].inner_size;
scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
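      // Rows shorter than one SIMD vector are copied element-by-element;
      // longer rows go through the vectorized identity map, which lowers to
      // wide vector loads and stores.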
if (local_inner < Vec::size()) {
for (const auto k : c10::irange(local_inner)) {
result_ptr[k] = input_ptr[k];
}
} else {
vec::map(
[](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
}
result_ptr += local_inner;
}
}
}
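// A minimal usage sketch (hypothetical caller, shown for illustration; it
// assumes `result` has already been resized to the stacked shape and that
// the inputs passed the checks below). Dtype dispatch uses the standard
// ATen macro:
//
//   void stack_serial_kernel(Tensor& result, TensorList tensors, int64_t dim) {
//     AT_DISPATCH_FLOATING_TYPES(
//         result.scalar_type(), "stack_serial_kernel", [&]() {
//           stack_serial_kernel_impl<scalar_t, TensorList>(result, tensors, dim);
//         });
//   }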
// Checks whether the native stack fast path can be invoked under these conditions:
// - result and input tensors are contiguous
// - only one thread is used
// - no type promotion has to occur
// - tensors dtype is Double or Float
template <typename TensorListType>
bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
const Tensor& first_tensor = tensors[0];
// stack dimension should be in range [0, first_tensor.dim())
// dim == first_tensor.dim() is a valid input, but it is handled by the
// default code path that uses unsqueeze
if (dim >= first_tensor.dim()) return false;
// Native stack doesn't apply if any tensor would be skipped (the default
// path skips empty 1-d tensors).
if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
// there should be no type promotion
if (result.dtype() != first_tensor.dtype()) return false;
auto first_tensor_mem_format = first_tensor.suggest_memory_format();
ScalarType dtype = first_tensor.scalar_type();
if (!result.is_contiguous(first_tensor_mem_format)) {
return false;
}
// fast path only works for Double and Float
if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
return false;
}
// check remainder of inputs
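// first_tensor_shape is only used by the TORCH_CHECK message below, so it is
// guarded to avoid an unused-variable warning when error messages are
// stripped out of the build.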
#ifndef STRIP_ERROR_MESSAGES
auto const &first_tensor_shape = first_tensor.sizes();
#endif
for (const auto i : c10::irange(1, tensors.size())) {
auto const &tensor = tensors[i];
TORCH_CHECK(tensor.sizes() == first_tensor.sizes(),
"stack expects each tensor to be equal size, but got ", first_tensor_shape,
" at entry 0 and ", tensor.sizes(), " at entry ", i);
// every tensor must be contiguous
// tensor sizes and strides must be the same
// there should be no type promotion
if (!tensor.is_contiguous(first_tensor_mem_format) ||
tensor.strides() != first_tensor.strides() ||
tensor.dtype() != dtype) {
return false;
}
}
// fast native stack should only be used when it is not worth using multiple threads
// or there is only one thread. Note that we aren't checking result.numel() here because
// it may not have been resized and we want to defer that cost till later.
int64_t numel_in_stack = first_tensor.numel() * tensors.size();
return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
}
template <typename TensorListType, bool should_skip_overlap_check>
struct CanUseNativeSerialStack;
template <typename TensorListType>
struct CanUseNativeSerialStack<TensorListType, false> {
static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
// Inputs cannot alias the output tensor
for (const auto i : c10::irange(tensors.size())) {
auto lap = at::get_overlap_status(result, tensors[i]);
TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
    lap != at::MemOverlapStatus::Full,
"unsupported operation: the input tensors cannot refer to any of the "
"output memory locations. Found overlap in input tensor ", i);
}
return can_use_native_serial_stack_impl(result, tensors, dim);
}
};
template <typename TensorListType>
struct CanUseNativeSerialStack<TensorListType, true> {
static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
return can_use_native_serial_stack_impl(result, tensors, dim);
}
};
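// Sketch of the intended call pattern (hypothetical caller, for illustration
// only): check eligibility first, then run the serial kernel.
//
//   if (CanUseNativeSerialStack<at::TensorList,
//           /*should_skip_overlap_check=*/false>::call(result, tensors, dim)) {
//     // resize `result` to the stacked shape, then invoke
//     // stack_serial_kernel_impl via AT_DISPATCH_FLOATING_TYPES.
//   }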
} // namespace at::native::detail