File size: 13,487 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 |
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <type_traits>
namespace at::native {
inline namespace CPU_CAPABILITY {
template <typename scalar_t>
int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) {
// Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
// size of L1D cache on many processors. Some processors have 48 KB L1D cache
// nowadays, so maybe in the future, we can leverage the knowledge of a
// machine's L1D cache size.
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
1,
grain_size / (sizeof(scalar_t) * dim_size));
return std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
}
template <typename scalar_t>
void serial_vec_log_softmax_lastdim_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t dim_size,
int64_t chunk_size,
int64_t begin,
int64_t end) {
if (end <= begin) {
return;
}
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
// MSVC requires such a declaration of dynamic arrays
// Source: https://stackoverflow.com/a/33423538
auto tmp_sum_scalar = std::make_unique<scalar_t[]>(chunk_size);
auto max_input_arr = std::make_unique<scalar_t[]>(chunk_size);
for (int64_t ii = begin; ii < end; ii += chunk_size) {
int64_t loop_end = chunk_size;
if (ii + chunk_size > end) {
loop_end = end - ii;
}
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
max_input_arr[j] = vec::reduce_all<scalar_t>(
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
input_data,
dim_size);
}
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
scalar_t max_input = max_input_arr[j];
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
[](Vec x, Vec y) { return x + y; },
input_data,
dim_size);
}
// See [Note AVX-SSE transitions] for why this should call the
// vectorized version (aside from perf improvements).
vec::map(
[](Vec x) { return x.log(); },
tmp_sum_scalar.get(),
tmp_sum_scalar.get(),
loop_end);
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
scalar_t* output_data = output_data_base + i * dim_size;
scalar_t tmp_sum = tmp_sum_scalar[j];
scalar_t max_input = max_input_arr[j];
// It's necessary to keep the order of the operations below.
// In some cases that input is large digits and the difference
// is small, if we compute `max_input` plus `tmp_sum` before,
// there would be a numerical problem. See an example in
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
vec::map(
[tmp_sum, max_input](Vec x) {
return x - Vec(max_input) - Vec(tmp_sum);
},
output_data,
input_data,
dim_size);
}
}
}
// Can't include ATen/Parallel.h.
// TODO: find a way to have only one copy of divup.
inline int64_t divup(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
template <typename scalar_t, int64_t BLOCK_SIZE = 128 * 1024>
std::pair<int64_t,int64_t> vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
return {CHUNK_SIZE, num_chunks};
}
template <typename scalar_t>
std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t inner_size,
int64_t chunk_size,
int64_t num_chunks,
int64_t dim_size,
int64_t begin,
int64_t end) {
using Vec = vec::Vectorized<scalar_t>;
// thread local temp buffer which holds vertical reduction result: max and sum.
auto buffer = std::make_unique<scalar_t []>(chunk_size * 2);
scalar_t* input_max_data = buffer.get();
scalar_t* tmp_sum_data = buffer.get() + chunk_size;
for (int64_t i = begin; i < end; i++) {
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * chunk_size;
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
// init
Vec zero_vec = Vec(scalar_t(0));
Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_vec.store(input_max_data + d0);
zero_vec.store(tmp_sum_data + d0);
}
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
tmp_sum_data[d0] = scalar_t(0);
}
// compute max
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
Vec max_vec = Vec::loadu(input_max_data + d1);
max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
max_vec.store(input_max_data + d1);
}
for (; d1 < size; d1++) {
scalar_t data_val = input_ptr[d1];
scalar_t max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
}
}
// compute sum of (x - max).exp()
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d2);
Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
Vec max_vec = Vec::loadu(input_max_data + d2);
sum_vec += (data_vec - max_vec).exp();
sum_vec.store(tmp_sum_data + d2);
}
for (; d2 < size; d2++) {
scalar_t data_val = input_ptr[d2];
scalar_t max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
const scalar_t* input_ptr = input_data_base + offset;
scalar_t* output_ptr = output_data_base + offset;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d3);
Vec max_vec = Vec::loadu(input_max_data + d3);
Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
Vec out_vec = data_vec - max_vec - sum_vec;
out_vec.store(output_ptr + d3);
}
for (; d3 < size; d3++) {
output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
}
}
}
}
template <typename scalar_t>
std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t inner_size,
int64_t chunk_size,
int64_t num_chunks,
int64_t dim_size,
int64_t begin,
int64_t end) {
using Vec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
auto buffer = std::make_unique<float []>(chunk_size * 2);
float* input_max_data = buffer.get();
float* tmp_sum_data = buffer.get() + chunk_size;
// thread local buffer that holds input data in float32 to save next 2 dtype conversion
auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
float* input_buffer_data = input_buffer.get();
// init
for (int64_t i = begin; i < end; i++) {
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * chunk_size;
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
fVec zero_fvec = fVec(float(0));
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_fvec.store(input_max_data + d0);
min_fvec.store(input_max_data + d0 + fVec::size());
zero_fvec.store(tmp_sum_data + d0);
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
}
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<float>::infinity();
tmp_sum_data[d0] = float(0);
}
// compute max
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
max_fvec0.store(input_max_data + d1);
max_fvec1.store(input_max_data + d1 + fVec::size());
// cache the 'converted' float input
data_fvec0.store(input_buffer_ptr + d1);
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
}
for (; d1 < size; d1++) {
float data_val = float(input_ptr[d1]);
float max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
input_buffer_ptr[d1] = data_val;
}
}
// compute sum of (x - max).exp()
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
sum_fvec0.store(tmp_sum_data + d2);
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
}
for (; d2 < size; d2++) {
float data_val = input_buffer_ptr[d2];
float max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
out_vec.store(output_ptr + d3);
}
for (; d3 < size; d3++) {
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
}
}
}
} // namespace CPU_CAPABILITY
}} // namespace at::native
|