|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef _WIN32
|
|
|
#define _USE_MATH_DEFINES
|
|
|
#include <cmath>
|
|
|
#endif
|
|
|
|
|
|
#include <ATen/cpu/vec/vec.h>
|
|
|
#include <c10/util/BFloat16.h>
|
|
|
|
|
|
namespace at::native {
|
|
|
inline namespace CPU_CAPABILITY {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns a scalar ELU functor that maps one ParamT element to one ParamT
// element:
//   f(x) = expm1(x * input_scale) * (alpha * scale)   when x < 0
//   f(x) = x * scale                                  otherwise
// The arithmetic is carried out in MathT (defaults to ParamT, e.g. a wider
// type for reduced-precision inputs); the result is converted back to ParamT
// on return. The combined coefficients are folded once here and captured by
// value, so the returned lambda owns everything it needs.
template <typename ParamT, typename MathT=ParamT>
auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) {
  const auto scaled_alpha = alpha * scale;
  const auto linear_slope = scale;
  const auto exp_input_scale = input_scale;
  return [scaled_alpha, exp_input_scale, linear_slope](ParamT value) -> ParamT {
    const MathT x = MathT(value);
    // Strict "< 0": zero, negative zero and NaN all take the linear branch.
    if (x < MathT(0)) {
      return std::expm1(x * exp_input_scale) * scaled_alpha;
    }
    return x * linear_slope;
  };
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
|
|
auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) {
|
|
|
const vec::Vectorized<T> negcoef_vec(alpha * scale);
|
|
|
const vec::Vectorized<T> poscoef_vec(scale);
|
|
|
const vec::Vectorized<T> negiptcoef_vec(input_scale);
|
|
|
const vec::Vectorized<T> zero_vec(static_cast<T>(0));
|
|
|
return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized<T> a) -> vec::Vectorized<T> {
|
|
|
const auto cmp = a >= zero_vec;
|
|
|
if (!cmp.zero_mask()) {
|
|
|
return a * poscoef_vec;
|
|
|
} else {
|
|
|
return vec::Vectorized<T>::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp);
|
|
|
}
|
|
|
};
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
|
|
auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) {
|
|
|
|
|
|
const auto float_func = get_vectorized_elu_elementwise_func<float>(alpha, scale, input_scale);
|
|
|
return [float_func](vec::Vectorized<T> a) -> vec::Vectorized<T> {
|
|
|
auto [a0, a1] = vec::convert_to_float<T>(a);
|
|
|
auto res0 = float_func(a0);
|
|
|
auto res1 = float_func(a1);
|
|
|
return vec::convert_from_float<T>(res0, res1);
|
|
|
};
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|