|
|
#include <c10/util/Exception.h>
|
|
|
#include <utility>
|
|
|
|
|
|
namespace at {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
inline std::pair<int64_t, int64_t> collapse_dims(
|
|
|
T* sizes,
|
|
|
T* strides,
|
|
|
int64_t dims,
|
|
|
const int excludeDim = -1) {
|
|
|
TORCH_CHECK(
|
|
|
excludeDim >= -1 && excludeDim < dims,
|
|
|
"expected excluded dim between -1 and dims - 1");
|
|
|
|
|
|
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
|
|
|
int64_t newIndex = -1;
|
|
|
int64_t oldIndex = 0;
|
|
|
int64_t remappedExcludedDim = -1;
|
|
|
|
|
|
while (oldIndex < dims) {
|
|
|
|
|
|
for (; oldIndex < stopDim; ++oldIndex) {
|
|
|
if (sizes[oldIndex] == 1) {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
++newIndex;
|
|
|
sizes[newIndex] = sizes[oldIndex];
|
|
|
strides[newIndex] = strides[oldIndex];
|
|
|
++oldIndex;
|
|
|
break;
|
|
|
}
|
|
|
|
|
|
|
|
|
for (; oldIndex < stopDim; ++oldIndex) {
|
|
|
if (sizes[oldIndex] == 1) {
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
|
|
|
sizes[newIndex] *= sizes[oldIndex];
|
|
|
strides[newIndex] = strides[oldIndex];
|
|
|
} else {
|
|
|
++newIndex;
|
|
|
sizes[newIndex] = sizes[oldIndex];
|
|
|
strides[newIndex] = strides[oldIndex];
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
if (oldIndex != dims) {
|
|
|
|
|
|
++newIndex;
|
|
|
sizes[newIndex] = sizes[oldIndex];
|
|
|
strides[newIndex] = strides[oldIndex];
|
|
|
remappedExcludedDim = newIndex;
|
|
|
|
|
|
|
|
|
++oldIndex;
|
|
|
stopDim = dims;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
|
|
|
dims = 1;
|
|
|
sizes[0] = 1;
|
|
|
strides[0] = 1;
|
|
|
|
|
|
return std::pair<int64_t, int64_t>(0, 1);
|
|
|
}
|
|
|
|
|
|
dims = newIndex + 1;
|
|
|
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|