|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
#include "./FbgemmBuild.h" |
|
|
#include "./UtilsAvx2.h" |
|
|
|
|
|
#include <algorithm> |
|
|
#include <array> |
|
|
#include <cassert> |
|
|
#include <cmath> |
|
|
#include <iomanip> |
|
|
#include <iostream> |
|
|
#include <string> |
|
|
#include <type_traits> |
|
|
|
|
|
// HAVE_SVE feature switch.
//
// If HAVE_SVE has not been defined before this point, it is set to 1 when the
// translation unit targets AArch64 with __ARM_FEATURE_SVE enabled (and the SVE
// intrinsics headers are included), otherwise to 0. A definition provided
// earlier is left untouched, in which case no SVE header is included here.
#ifndef HAVE_SVE
#if defined(__aarch64__) && __ARM_FEATURE_SVE
#define HAVE_SVE 1
#include <arm_neon_sve_bridge.h>
#include <arm_sve.h>
#else
#define HAVE_SVE 0
#endif
#endif // HAVE_SVE
|
|
|
|
|
namespace fbgemm { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Compile-time predicate for 8-bit integer element types.
 *
 * `is_8bit<T>::value` is true exactly when T is int8_t or uint8_t and false
 * for every other type.
 */
template <typename T>
struct is_8bit {
  static constexpr bool value =
      std::disjunction_v<std::is_same<T, int8_t>, std::is_same<T, uint8_t>>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Selects whether a matrix operand is read as stored or transposed.
 *
 * For example, printMatrix() in this header reads element (r, c) from
 * inp[r * ld + c] for NoTranspose and from inp[c * ld + r] for Transpose.
 */
enum class matrix_op_t {
  NoTranspose,
  Transpose
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Identifies an instruction-set target.
 *
 * Values of this type are returned by fbgemmInstructionSet() and accepted by
 * fbgemmForceIsa(), isZmm() and isYmm(). The per-enumerator notes are read off
 * the names only -- confirm against the implementation.
 */
enum class inst_set_t {
  anyarch, // presumably: no specific ISA requirement
  avx2, // presumably: AVX2
  avx512, // presumably: AVX-512
  avx512_ymm, // presumably: AVX-512 restricted to ymm registers
  avx512_vnni, // presumably: AVX-512 with VNNI
  avx512_vnni_ymm, // presumably: AVX-512 VNNI restricted to ymm registers
  sve // presumably: Arm SVE
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Names the convolution code paths.
 *
 * NOTE(review): how a path is chosen is not visible in this header; the
 * enumerator names are the only information available here.
 */
enum class optimized_conv_t {
  depthwise,
  groupwise,
  pointwise,
  fastpath1d,
  im2col,
  directconv
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Implementation flavour selector.
 *
 * Presumably `ref` stands for a reference implementation and `opt` for an
 * optimized one -- confirm against the users of this enum.
 */
enum class impl_type_t {
  ref,
  opt
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Tensor layout selector.
 *
 * NOTE(review): the meaning of the letters in KCX / KXC is not established by
 * anything in this header; presumably they name dimension orderings of a
 * weight tensor -- confirm before relying on it.
 */
enum class FBGEMM_ENUM_CLASS_API layout_t {
  KCX,
  KXC
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
FBGEMM_API int compare_buffers( |
|
|
const T* ref, |
|
|
const T* test, |
|
|
int m, |
|
|
int n, |
|
|
int ld, |
|
|
size_t max_mismatches_to_report, |
|
|
float atol = 1e-3); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
void printMatrix( |
|
|
matrix_op_t op, |
|
|
const T* inp, |
|
|
size_t R, |
|
|
size_t C, |
|
|
size_t ld, |
|
|
const std::string& name) { |
|
|
|
|
|
|
|
|
|
|
|
std::cout << name << ":" << "[" << R << ", " << C << "]" << '\n'; |
|
|
bool tr = (op == matrix_op_t::Transpose); |
|
|
for (size_t r = 0; r < R; ++r) { |
|
|
for (size_t c = 0; c < C; ++c) { |
|
|
T res = tr ? inp[c * ld + r] : inp[r * ld + c]; |
|
|
if constexpr (std::is_integral_v<T>) { |
|
|
std::cout << std::setw(5) << static_cast<int64_t>(res) << " "; |
|
|
} else { |
|
|
std::cout << std::setw(5) << res << " "; |
|
|
} |
|
|
} |
|
|
std::cout << '\n'; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
FBGEMM_API void transpose_simd( |
|
|
int64_t M, |
|
|
int64_t N, |
|
|
const T* src, |
|
|
int64_t ld_src, |
|
|
T* dst, |
|
|
int64_t ld_dst); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Requests a specific instruction set.
 *
 * Declaration only. Judging by the name, the argument overrides the value
 * that fbgemmInstructionSet() would otherwise report -- confirm against the
 * definition.
 */
FBGEMM_API void fbgemmForceIsa(inst_set_t /*isa*/);

/**
 * @brief Switches the "AVX-512 with ymm registers" mode on or off.
 *
 * Declaration only; the effect of the flag is defined by the implementation.
 */
FBGEMM_API void fbgemmEnableAvx512Ymm(bool /*flag*/);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// CPU capability queries.
//
// All of these are declarations only. Each presumably reports whether the
// running CPU has the property named in the function -- confirm against the
// definitions.

/** @brief Presumably true when running on an Intel Xeon D processor. */
FBGEMM_API bool fbgemmIsIntelXeonD();

/** @brief Presumably true when AVX-512 is available. */
FBGEMM_API bool fbgemmHasAvx512Support();

/** @brief Presumably true when AVX2 is available. */
FBGEMM_API bool fbgemmHasAvx2Support();

/** @brief Presumably true when AVX-512 VNNI is available. */
FBGEMM_API bool fbgemmHasAvx512VnniSupport();

/** @brief Presumably true when Arm NEON is available. */
FBGEMM_API bool fbgemmHasArmNeonSupport();

/** @brief Presumably true when Arm SVE is available. */
FBGEMM_API bool fbgemmHasArmSveSupport();

/** @brief Presumably true when Arm SVE2 is available. */
FBGEMM_API bool fbgemmHasArmSve2Support();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Returns the instruction set in effect.
 *
 * Declaration only. isValidBlockingFactor() in this header feeds the result
 * to isZmm() / isYmm().
 */
FBGEMM_API inst_set_t fbgemmInstructionSet();

/**
 * @brief Classifies an instruction set as zmm-based.
 *
 * Declaration only; presumably true for ISAs that operate on 512-bit zmm
 * registers -- confirm against the definition.
 */
FBGEMM_API bool isZmm(inst_set_t /*isa*/);

/**
 * @brief Classifies an instruction set as ymm-based.
 *
 * Declaration only; presumably true for ISAs that operate on 256-bit ymm
 * registers -- confirm against the definition.
 */
FBGEMM_API bool isYmm(inst_set_t /*isa*/);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Plain aggregate of blocking parameters.
 *
 * The constraints noted on the fields are the ones enforced by
 * isValidBlockingFactor() in this header. What each parameter means inside
 * the kernels is not visible here.
 */
struct FBGEMM_API BlockingFactors {
  int MR; // MCB must be a multiple of MR
  int NR; // NCB must be a multiple of NR
  int NR_MIN; // on zmm/ymm ISAs NR must be a multiple of NR_MIN
  int ROW_INTERLEAVE; // 4 for int32 accumulation, 2 for int16 accumulation
  int MCB;
  int KCB; // not checked by isValidBlockingFactor()
  int NCB;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct FBGEMM_API thread_type_t { |
|
|
int g_num_threads; |
|
|
int m_num_threads; |
|
|
int n_num_threads; |
|
|
int g_thread_id; |
|
|
int m_thread_id; |
|
|
int n_thread_id; |
|
|
|
|
|
std::string toString() const { |
|
|
std::string out; |
|
|
out += "g num threads: " + std::to_string(g_num_threads) + ", "; |
|
|
out += "m num threads: " + std::to_string(m_num_threads) + ", "; |
|
|
out += "n num threads: " + std::to_string(n_num_threads) + ", "; |
|
|
out += "g thread id: " + std::to_string(g_thread_id) + ", "; |
|
|
out += "m thread id: " + std::to_string(m_thread_id) + ", "; |
|
|
out += "n thread id: " + std::to_string(n_thread_id); |
|
|
return out; |
|
|
} |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Computes a 2D split of `nthreads` threads over an m x n problem.
 *
 * Declaration only; everything below is inferred from the parameter names and
 * must be confirmed against the definition.
 *
 * @param m             presumably the extent of the m dimension
 * @param n             presumably the extent of the n dimension
 * @param nthreads      number of threads to distribute
 * @param n_align       presumably an alignment requirement along n
 * @param aspect_ratio  presumably the preferred ratio between the two splits
 *
 * @return an int; presumably the number of threads assigned to one of the two
 *         dimensions.
 */
FBGEMM_API int fbgemmGet2DPartition(
    int m,
    int n,
    int nthreads,
    int n_align,
    double aspect_ratio);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Computes the thread_type_t describing one thread's position.
 *
 * Declaration only; the notes are inferred from the names -- confirm against
 * the definition.
 *
 * @param g            presumably the extent of the g (group) dimension
 * @param m            presumably the extent of the m dimension
 * @param n            presumably the extent of the n dimension
 * @param thread_id    id of the calling thread
 * @param num_threads  total number of threads
 * @param n_align      presumably an alignment requirement along n;
 *                     defaults to 64
 */
FBGEMM_API thread_type_t fbgemmGetThreadPartition(
    int g,
    int m,
    int n,
    int thread_id,
    int num_threads,
    int n_align = 64);
|
|
|
|
|
/**
 * @brief Formats a std::array as "[a, b, c]".
 *
 * Elements are converted with std::to_string and separated by ", ". An empty
 * array yields "[]"; previously the closing bracket was only appended while
 * processing the last element, so SIZE == 0 produced an unterminated "[".
 *
 * SIZE is declared `int` while std::array's extent is std::size_t, so SIZE
 * cannot be deduced from the argument and has to be spelled at the call site,
 * e.g. arrayToString<3>(a).
 *
 * @tparam SIZE number of elements in the array
 * @tparam T    element type; must be accepted by std::to_string
 * @param inp   array to format
 * @return the formatted string
 */
template <int SIZE, typename T = std::int32_t>
std::string arrayToString(const std::array<T, SIZE>& inp) {
  std::string out = "[";
  for (int i = 0; i < SIZE; ++i) {
    if (i != 0) {
      out += ", ";
    }
    out += std::to_string(inp[i]);
  }
  // Close unconditionally so that the zero-length case is well formed too.
  out += "]";
  return out;
}
|
|
|
|
|
template <typename accT = std::int32_t> |
|
|
bool isValidBlockingFactor(const BlockingFactors* const param) { |
|
|
constexpr bool is_32bit = std::is_same_v<accT, int32_t>; |
|
|
constexpr bool is_16bit = std::is_same_v<accT, int16_t>; |
|
|
static const auto iset = fbgemmInstructionSet(); |
|
|
|
|
|
if constexpr (is_32bit) { |
|
|
if (param->ROW_INTERLEAVE != 4) |
|
|
return false; |
|
|
|
|
|
if (isZmm(iset)) { |
|
|
if (param->NR_MIN != 16 || param->NR % param->NR_MIN) |
|
|
return false; |
|
|
} else if (isYmm(iset)) { |
|
|
if (param->NR_MIN != 8 || param->NR % param->NR_MIN) |
|
|
return false; |
|
|
} |
|
|
} else if constexpr (is_16bit) { |
|
|
if (param->ROW_INTERLEAVE != 2) |
|
|
return false; |
|
|
|
|
|
if (isZmm(iset)) { |
|
|
if (param->NR_MIN != 32 || param->NR % param->NR_MIN) |
|
|
return false; |
|
|
} else if (isYmm(iset)) { |
|
|
if (param->NR_MIN != 16 || param->NR % param->NR_MIN) |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
if (param->MCB % param->MR) |
|
|
return false; |
|
|
if (param->NCB % param->NR) |
|
|
return false; |
|
|
if (isZmm(iset)) { |
|
|
if constexpr (is_32bit) { |
|
|
|
|
|
if (param->MR * (param->NR / param->NR_MIN) > 28) |
|
|
return false; |
|
|
} else if constexpr (is_16bit) { |
|
|
|
|
|
if ((param->MR * (param->NR / param->NR_MIN) + |
|
|
(param->NR / param->NR_MIN)) > 28) |
|
|
return false; |
|
|
} |
|
|
|
|
|
} else if (isYmm(iset)) { |
|
|
if (param->MR * (param->NR / param->NR_MIN) > 12) |
|
|
return false; |
|
|
} |
|
|
return true; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Computes one thread's share of a 1D range of work.
 *
 * Declaration only; the notes are inferred from the names -- confirm against
 * the definition.
 *
 * @param thread_id    id of the thread whose share is requested
 * @param num_threads  total number of threads
 * @param total_work   total amount of work to split
 * @param start        output; presumably the first work item of this thread
 * @param end          output; presumably one past this thread's last item
 */
FBGEMM_API void fbgemmPartition1D(
    int thread_id,
    int num_threads,
    std::int64_t total_work,
    std::int64_t& start,
    std::int64_t& end);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Block-aware variant of fbgemmPartition1D.
 *
 * Declaration only; the notes are inferred from the names -- confirm against
 * the definition.
 *
 * @param thread_id    id of the thread whose share is requested
 * @param num_threads  total number of threads
 * @param total_work   total amount of work to split
 * @param block_size   presumably the granularity at which work is handed out
 * @param start        output; presumably the first work item of this thread
 * @param end          output; presumably one past this thread's last item
 */
FBGEMM_API void fbgemmPartition1DBlocked(
    int thread_id,
    int num_threads,
    std::int64_t total_work,
    int block_size,
    std::int64_t& start,
    std::int64_t& end);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename K, typename V> |
|
|
FBGEMM_API std::pair<K*, V*> radix_sort_parallel( |
|
|
K* const inp_key_buf, |
|
|
V* const inp_value_buf, |
|
|
K* const tmp_key_buf, |
|
|
V* const tmp_value_buf, |
|
|
const int64_t elements_count, |
|
|
const int64_t max_value, |
|
|
const bool maybe_with_neg_vals = false); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Declaration only; presumably reports whether radix_sort_parallel()
 * was built with OpenMP parallelism -- confirm against the definition.
 */
FBGEMM_API bool is_radix_sort_accelerated_with_openmp();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Runtime configuration queries.
//
// Declarations only. Where each setting comes from (environment, build flags,
// ...) is not visible in this header; the one-line notes are read off the
// function names -- confirm against the definitions.

/** @brief Presumably true when auto-vectorized kernels are disabled. */
FBGEMM_API bool is_autovec_disabled();

/** @brief Presumably true when auto-vectorized kernels are forced. */
FBGEMM_API bool is_autovec_forced();

/** @brief Presumably true when asmjit-generated kernels are disabled. */
FBGEMM_API bool is_asmjit_disabled();

/** @brief Presumably true when statistics collection is enabled. */
FBGEMM_API bool is_stats_enabled();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Debug-only validation of n-bit embedding parameters.
 *
 * All checks are plain assert()s, so the function does nothing when NDEBUG is
 * defined; the parameters are tagged [[maybe_unused]] for that configuration.
 *
 *  - input_bit_rate must be 2 or 4.
 *  - When OutType is uint8_t, only no_bag with input_bit_rate == 4 and
 *    output_bit_rate == 4 is accepted.
 *  - For any other OutType, output_bit_rate must equal 8 * sizeof(OutType).
 *
 * @tparam OutType         output element type
 * @param input_bit_rate   bits per input element
 * @param output_bit_rate  bits per output element
 * @param no_bag           required to be true for the uint8_t output case
 */
template <typename OutType>
void nbit_embedding_sanity_check(
    [[maybe_unused]] const int input_bit_rate,
    [[maybe_unused]] const int output_bit_rate,
    [[maybe_unused]] const bool no_bag) {
  assert(
      (input_bit_rate == 2 || input_bit_rate == 4) &&
      "input_bit_rate must be 2 or 4");

  constexpr bool out_is_uint8 = std::is_same_v<OutType, uint8_t>;
  if constexpr (!out_is_uint8) {
    assert(
        (output_bit_rate == 8 * sizeof(OutType)) &&
        "output_bit_rate should be equal to 8 * sizeof(OutType)");
  } else {
    assert(
        (no_bag && input_bit_rate == 4 && output_bit_rate == 4) &&
        "we currently only support int4 to int4 for sequential TBE");
  }
}
|
|
|
|
|
// WARN_ONCE(fmt, ...): prints a printf-style message to stderr the first time
// a given expansion site is executed and does nothing on later executions.
// Each expansion gets its own function-local static flag. The arguments are
// evaluated only when the message is actually printed.
#define WARN_ONCE(...)            \
  do {                            \
    static bool _warned = false;  \
    if (_warned) {                \
      break;                      \
    }                             \
    _warned = true;               \
    fprintf(stderr, __VA_ARGS__); \
  } while (0)
|
|
|
|
|
} |
|
|
|