File size: 15,476 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 |
#pragma once
#include <ATen/Device.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/utils/ParamsHash.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type_native.h>
#endif
#include <algorithm>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>
namespace at::native {
namespace {
// Returns true if at least one tensor in `tensors` has an integral dtype.
// `includeBool` is forwarded to at::isIntegralType and decides whether a
// Bool tensor also counts as integral.
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
  const auto is_integral = [includeBool](const Tensor& tensor) {
    return at::isIntegralType(tensor.scalar_type(), includeBool);
  };
  return std::any_of(tensors.begin(), tensors.end(), is_integral);
}
// Returns true if at least one tensor in `tensors` has dtype Bool.
inline bool has_bool_tensor(TensorList tensors) {
  const auto is_bool = [](const Tensor& tensor) -> bool {
    return tensor.scalar_type() == ScalarType::Bool;
  };
  return std::any_of(tensors.begin(), tensors.end(), is_bool);
}
// Check foreach API restrictions.
// The check_foreach_api_restrictions overloads below enforce that:
// - every TensorList argument is non-empty, and
// - all TensorLists and ScalarLists have the same number of elements.
// They compare list lengths only; they do not compare the sizes of
// individual tensors.
//
// Single-list form: rejects an empty tensor list.
inline void check_foreach_api_restrictions(TensorList tensors) {
  const bool has_tensors = !tensors.empty();
  TORCH_CHECK(has_tensors, "Tensor list must have at least one tensor.");
}
// Tensor list + scalar list form: the tensor list must be non-empty and hold
// exactly one scalar per tensor.
inline void check_foreach_api_restrictions(
    TensorList tensors,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  const bool lengths_match = scalars.size() == tensors.size();
  TORCH_CHECK(
      lengths_match,
      "Tensor list must have same number of elements as scalar list.");
}
// Two-list form: both lists must be non-empty (checked in argument order) and
// of equal length.
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2) {
  check_foreach_api_restrictions(tensors1);
  check_foreach_api_restrictions(tensors2);
  const auto len1 = tensors1.size();
  const auto len2 = tensors2.size();
  TORCH_CHECK(
      len1 == len2,
      "Tensor lists must have the same number of tensors, got ",
      len1,
      " and ",
      len2);
}
// Three-list form: all lists must be non-empty, and lists 2 and 3 must each
// match the length of list 1. Emptiness is checked for every list before any
// length comparison, so an empty list is reported first.
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  check_foreach_api_restrictions(tensors1);
  check_foreach_api_restrictions(tensors2);
  check_foreach_api_restrictions(tensors3);
  const auto len1 = tensors1.size();
  const auto len2 = tensors2.size();
  const auto len3 = tensors3.size();
  TORCH_CHECK(
      len1 == len2,
      "Tensor lists must have the same number of tensors, got ",
      len1,
      " and ",
      len2);
  TORCH_CHECK(
      len1 == len3,
      "Tensor lists must have the same number of tensors, got ",
      len1,
      " and ",
      len3);
}
// Three lists + scalar list form: applies the three-list checks, then requires
// one scalar per tensor of the first list.
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  const auto num_tensors = tensors1.size();
  const auto num_scalars = scalars.size();
  TORCH_CHECK(
      num_tensors == num_scalars,
      "Tensor list must have same number of elements as scalar list, got ",
      num_tensors,
      " and ",
      num_scalars);
}
// Two lists + scalar list form: applies the two-list checks, then requires one
// scalar per tensor of the first list.
inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2);
  const auto num_tensors = tensors1.size();
  const auto num_scalars = scalars.size();
  TORCH_CHECK(
      num_tensors == num_scalars,
      "Tensor list must have same number of elements as scalar list, got ",
      num_tensors,
      " and ",
      num_scalars);
}
// Helper for check_fast_path_restrictions. Returns true only if every tensor in
// every list:
//   - has the same dtype as tensorLists[0][0] (skipped when skip_dtype_check),
//   - is on the same device as tensorLists[0][0],
//   - has strided layout, and
//   - is non-overlapping and dense.
// Precondition: tensorLists and tensorLists[0] are non-empty; the first
// element is indexed without a bounds check.
inline bool _check_tensors_share_device_and_dtype(
    ArrayRef<TensorList> tensorLists,
    const bool skip_dtype_check = false) {
  const auto& reference = tensorLists[0][0];
  const auto expected_dtype = reference.dtype();
  const auto expected_device = reference.device();
  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!skip_dtype_check && tensor.dtype() != expected_dtype) {
        return false;
      }
      if (tensor.device() != expected_device ||
          tensor.layout() != at::kStrided ||
          !tensor.is_non_overlapping_and_dense()) {
        return false;
      }
    }
  }
  return true;
}
// Helper for check_fast_path_restrictions. Returns true if, for every index j,
// tensorLists[i][j] has the same sizes as tensorLists[0][j] and the same
// strides in every dimension whose size is not 1. Strides of size-1 dimensions
// are ignored because they never affect addressing.
// Precondition: every list is at least as long as tensorLists[0].
inline bool _check_tensors_share_sizes_and_strides(
    ArrayRef<TensorList> tensorLists) {
  const auto strides_match = [](IntArrayRef sizes,
                                IntArrayRef lhs_strides,
                                IntArrayRef rhs_strides) -> bool {
    for (const auto d : c10::irange(sizes.size())) {
      if (sizes[d] != 1 && lhs_strides[d] != rhs_strides[d]) {
        return false;
      }
    }
    return true;
  };
  const auto& reference_list = tensorLists[0];
  for (const auto i : c10::irange(1, tensorLists.size())) {
    const auto& other_list = tensorLists[i];
    for (const auto j : c10::irange(reference_list.size())) {
      const auto& ref = reference_list[j];
      const auto& other = other_list[j];
      if (ref.sizes() != other.sizes()) {
        return false;
      }
      if (!strides_match(ref.sizes(), ref.strides(), other.strides())) {
        return false;
      }
    }
  }
  return true;
}
// Helper for check_fast_path_restrictions. Returns false if any tensor would
// change dtype as a result of the op:
//   - when does_op_promote_integer_inputs_to_float is set, any integral
//     (including Bool) tensor fails, because the result would be floating;
//   - when scalarList is non-empty, the tensor's dtype must equal
//     at::native::result_type(scalar, tensor). A single-element scalarList is
//     applied to every tensor; otherwise scalarList[i] pairs with tensor i.
// It inspects only `tensorList`, so it assumes
// _check_tensors_share_device_and_dtype has already established that
// corresponding tensors in the other lists have the same dtype.
inline bool _check_tensors_do_type_promotion_with_scalars(
    TensorList tensorList,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  const bool has_scalars = !scalarList.empty();
  const bool single_scalar = scalarList.size() == 1;
  for (const auto idx : c10::irange(tensorList.size())) {
    const auto& tensor = tensorList[idx];
    const auto dtype = tensor.scalar_type();
    // For division, integer inputs will result in float.
    if (does_op_promote_integer_inputs_to_float &&
        at::isIntegralType(dtype, /*includeBool*/ true)) {
      return false;
    }
    if (!has_scalars) {
      continue;
    }
    const auto& scalar = single_scalar ? scalarList[0] : scalarList[idx];
    // note(mkozuki): This check might be responsible for
    // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
    if (dtype != at::native::result_type(scalar, tensor)) {
      return false;
    }
  }
  return true;
}
// To take the 'fast' path, all of the following must hold:
// - All tensors in all lists have the same dtype.
// - All tensors are on the same device.
// - All tensors have strided layout.
// - All tensors are non-overlapping and dense.
// - Corresponding tensors across lists have the same sizes and the same
//   strides (strides of size-1 dimensions are ignored).
// - The result has the same dtype as the input.
// [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
// ``does_op_promote_integer_inputs_to_float=true`` means the op produces a
// floating result even for integer or boolean inputs, which the fast path
// does not currently support. When this flag is set, it keeps such inputs
// off the fast path.
// Call check_foreach_api_restrictions before calling this method; the helpers
// below rely on its preconditions (for example, non-empty lists).
inline bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  if (!_check_tensors_share_device_and_dtype(tensorLists)) {
    return false;
  }
  if (!_check_tensors_share_sizes_and_strides(tensorLists)) {
    return false;
  }
  return _check_tensors_do_type_promotion_with_scalars(
      tensorLists[0], scalarList, does_op_promote_integer_inputs_to_float);
}
inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
const Tensor& scalarList_,
int64_t expect_length) {
std::vector<c10::Scalar> scalarList;
TORCH_CHECK(
scalarList_.device() == c10::kCPU,
"Expected scalars to be on CPU, got ",
scalarList_.device(),
" instead.");
TORCH_CHECK(
scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
TORCH_CHECK(
scalarList_.dim() == 1,
"Expected packed scalar Tensor to be of dimension 1. Got ",
scalarList_.dim(),
" instead.");
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf,
kHalf,
kBool,
kBFloat16,
scalarList_.scalar_type(),
"convert_tensor_to_scalar_list",
[&]() {
const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
TORCH_CHECK(
(expect_length == scalarList_.size(0)),
"Expected length of scalars to match input of length ",
expect_length,
" but got ",
scalarList_.size(0),
" instead.");
for (int64_t i = 0; i < scalarList_.size(0); i++) {
scalarList.emplace_back(scalar_data[i]);
}
});
return scalarList;
}
// Returns whether the given tensor lists (and optional scalars) may use the
// fast foreach path. This is a thin alias for check_fast_path_restrictions.
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  const bool fast_path_ok = check_fast_path_restrictions(
      tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
  return fast_path_ok;
}
// Two-list convenience overload with no scalars. It bundles the lists and
// forwards them to the ArrayRef<TensorList> overload.
// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    TensorList tensors1,
    TensorList tensors2,
    bool does_op_promote_integer_inputs_to_float = false) {
  const TensorList lists[] = {tensors1, tensors2};
  return can_use_fast_route(
      ArrayRef<TensorList>(lists),
      ArrayRef<Scalar>(),
      does_op_promote_integer_inputs_to_float);
}
// Grouping key: a (device, scalar type) pair.
using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
// Positions of tensors within the original, ungrouped input lists.
using IndicesT = std::vector<std::size_t>;
// A list of lists whose elements may be undefined (std::nullopt).
using nested_optional_tensorvec_t =
    std::vector<std::vector<std::optional<at::Tensor>>>;
// One group: the grouped tensor lists together with their original indices.
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
// Hash map (despite the name) from grouping key to its group, using
// ParamsHash for DeviceDtypeKey.
using FlatMap = std::
    unordered_map<DeviceDtypeKey, TensorsAndIndicesT, ParamsHash<DeviceDtypeKey>>;
// Groups the tensors of `nested_tensorlist` by the (device, scalar type) of the
// tensor at the same position in the first list.
//
// `nested_tensorlist` is a list of tensor lists. The first list must be
// non-empty and every element of it must be defined. Every other list must be
// either empty or exactly as long as the first one. For each index j, the key
// comes from nested_tensorlist[0][j], and the j-th element of every non-empty
// list is appended to that key's group. An empty input list stays empty in
// every group, so list positions are preserved.
//
// The j-th element of another list is accepted when it is undefined, or when
// it is on the key's device and either has the key's dtype or is
// Float/Double, or when it is a Float/Double tensor on CPU (e.g. optimizer
// `step` tensors). Anything else fails a TORCH_CHECK.
//
// When `with_indices` is true, each group also records the original indices j
// of its entries, in increasing order. Otherwise the index vector is empty.
inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
    const nested_optional_tensorvec_t& nested_tensorlist,
    const bool with_indices) {
  FlatMap grouped_tensors_with_indices;
  TORCH_CHECK(!nested_tensorlist.empty());
  TORCH_CHECK(!nested_tensorlist[0].empty());
  const auto num_lists = nested_tensorlist.size();
  const auto num_tensors = nested_tensorlist[0].size();
  TORCH_CHECK(std::all_of(
      nested_tensorlist.cbegin(),
      nested_tensorlist.cend(),
      [&](const auto& tensorlist) -> bool {
        // note(crcrpar): Allow empty tensorlists following
        // ref:
        // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
        return tensorlist.size() == num_tensors || tensorlist.empty();
      }));

  // Builds an empty group: one inner vector per input list (pre-sized for
  // non-empty inputs) plus an index vector (pre-sized only if requested).
  const auto make_empty_group = [&]() -> TensorsAndIndicesT {
    nested_optional_tensorvec_t nested_tensorvec;
    nested_tensorvec.reserve(num_lists);
    for (const auto i : c10::irange(num_lists)) {
      std::vector<std::optional<at::Tensor>> tensors;
      if (!nested_tensorlist[i].empty()) {
        // NB: num_tensors is the max possible length for any of
        // the inner lists of tensor references. Reserving the max
        // trades memory for perf. This should not have significant
        // impact.
        tensors.reserve(num_tensors);
      }
      // Move, don't copy: a copied vector does not keep the source's
      // capacity, so copying would silently discard the reservation above.
      nested_tensorvec.emplace_back(std::move(tensors));
    }
    IndicesT indices;
    if (with_indices) {
      indices.reserve(num_tensors);
    }
    return {std::move(nested_tensorvec), std::move(indices)};
  };

  for (const auto tensor_index : c10::irange(num_tensors)) {
    const auto key = [&]() -> DeviceDtypeKey {
      // Bind by reference; copying the optional<Tensor> would bump the
      // tensor's refcount for no reason.
      const auto& t = nested_tensorlist[0][tensor_index];
      TORCH_CHECK(
          t.has_value(),
          "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
          "the ",
          tensor_index,
          "-th Tensor is not.");
      return {t->device(), t->scalar_type()};
    }();
    TORCH_CHECK(
        std::all_of(
            nested_tensorlist.cbegin(),
            nested_tensorlist.cend(),
            [&](const auto& tensorlist) -> bool {
              if (tensorlist.empty()) {
                return true;
              }
              const auto& tensor = tensorlist[tensor_index];
              // note(crcrpar): Currently the scope of this function is
              // optimizers so there could be `state_steps` and other scalars
              // whose elements are float tensors no matter what the parameter's
              // dtype is.
              if (!tensor.has_value()) {
                return true;
              }
              const auto s = tensor->scalar_type();
              const auto d = tensor->device();
              // Note: `step` or `state_step` is float32 by default.
              if (key.first == d) {
                return key.second == s || s == at::ScalarType::Float ||
                    s == at::ScalarType::Double;
              }
              if (d.is_cpu()) {
                // note(crcrpar): There are some test cases (e.g.
                // TestOptim::test_adam) where state_steps are on CPU and the
                // others are on CUDA. Currently a state_step Tensor has the
                // dtype of float.
                return s == at::ScalarType::Float ||
                    s == at::ScalarType::Double;
              }
              return false;
            }),
        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");

    // Look up the group once (creating it on first use) and append through
    // the returned reference, instead of re-hashing the key for every list.
    auto group_it = grouped_tensors_with_indices.find(key);
    if (group_it == grouped_tensors_with_indices.end()) {
      group_it =
          grouped_tensors_with_indices.emplace(key, make_empty_group()).first;
    }
    auto& group = group_it->second;
    for (const auto list_index : c10::irange(num_lists)) {
      if (!nested_tensorlist[list_index].empty()) {
        group.first[list_index].emplace_back(
            nested_tensorlist[list_index][tensor_index]);
      }
    }
    if (with_indices) {
      group.second.emplace_back(tensor_index);
    }
  }
  return grouped_tensors_with_indices;
}
} // namespace
} // namespace at::native
|