|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
#include <cstdint> |
|
|
#include <functional> |
|
|
|
|
|
#include "fbgemm/FbgemmBuild.h" |
|
|
|
|
|
namespace fbgemm { |
|
|
|
|
|
/// Holder for the callable type produced by the GenerateEmbeddingSpMDM*
/// factories in this header.
///
/// `Type` is a std::function returning bool whose parameters are, in order:
///   output_size, index_size, data_size  (std::int64_t)
///   input               const InType*
///   indices             const IndexType*
///   offsets_or_lengths  const OffsetType*
///   weights             const float*
///   out                 OutType*
///
/// NOTE(review): the meanings below are inferred from the parameter names and
/// are not shown in this header; confirm them against the implementation.
///   - output_size is presumably the number of pooled outputs (bags).
///   - index_size is presumably the length of `indices`.
///   - data_size is presumably the number of table rows that indices may
///     address.
///   - offsets_or_lengths presumably holds bag offsets or bag lengths,
///     selected by the factory's `use_offsets` flag.
///   - the bool result presumably reports success.
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float>
class EmbeddingSpMDMKernelSignature {
 public:
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t data_size,
      const InType* input,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      const float* weights,
      OutType* out)>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename InType, |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t, |
|
|
typename OutType = float, |
|
|
bool THREAD_LOCAL = false> |
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature< |
|
|
InType, |
|
|
IndexType, |
|
|
OffsetType, |
|
|
OutType>::Type |
|
|
GenerateEmbeddingSpMDM( |
|
|
const std::int64_t block_size, |
|
|
bool has_weight, |
|
|
bool normalize_by_lengths, |
|
|
int prefetch = 16, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true, |
|
|
bool is_bf16_out = false, |
|
|
bool is_bf16_in = false); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename InType, |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t, |
|
|
typename OutType = float, |
|
|
bool THREAD_LOCAL = false> |
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature< |
|
|
InType, |
|
|
IndexType, |
|
|
OffsetType, |
|
|
OutType>::Type |
|
|
GenerateEmbeddingSpMDMWithStrides( |
|
|
const std::int64_t block_size, |
|
|
bool has_weight, |
|
|
bool normalize_by_lengths, |
|
|
int prefetch = 16, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true, |
|
|
std::int64_t output_stride = -1, |
|
|
std::int64_t input_stride = -1, |
|
|
bool scale_bias_last = true, |
|
|
bool no_bag = false, |
|
|
bool is_bf16_out = false, |
|
|
bool is_bf16_in = false); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t, |
|
|
typename OutType = float> |
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature< |
|
|
std::uint8_t, |
|
|
IndexType, |
|
|
OffsetType, |
|
|
OutType>::Type |
|
|
GenerateEmbeddingSpMDMNBit( |
|
|
int bit_rate, |
|
|
const std::int64_t block_size, |
|
|
bool has_weight, |
|
|
bool normalize_by_lengths, |
|
|
int prefetch = 16, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t, |
|
|
typename OutType = float, |
|
|
bool THREAD_LOCAL = false> |
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature< |
|
|
std::uint8_t, |
|
|
IndexType, |
|
|
OffsetType, |
|
|
OutType>::Type |
|
|
GenerateEmbeddingSpMDMNBitWithStrides( |
|
|
const int input_bit_rate, |
|
|
const std::int64_t block_size, |
|
|
bool has_weight, |
|
|
bool normalize_by_lengths, |
|
|
int prefetch = 16, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true, |
|
|
std::int64_t output_stride = -1, |
|
|
std::int64_t input_stride = -1, |
|
|
bool scale_bias_last = true, |
|
|
const bool is_bf16_out = false, |
|
|
const bool no_bag = false, |
|
|
int output_bit_rate = -1); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t, |
|
|
typename OutType = float> |
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature< |
|
|
std::uint8_t, |
|
|
IndexType, |
|
|
OffsetType, |
|
|
OutType>::Type |
|
|
GenerateEmbeddingSpMDMFP8WithStrides( |
|
|
const std::int64_t block_size, |
|
|
bool normalize_by_lengths, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true, |
|
|
std::int64_t output_stride = -1, |
|
|
std::int64_t input_stride = -1, |
|
|
int exponent_bits = 4, |
|
|
int exponent_bias = 7, |
|
|
bool is_bf16_out = false); |
|
|
|
|
|
/// Holder for the callable type produced by the
/// GenerateEmbedding*RowWiseSparse factories.
///
/// The layout matches EmbeddingSpMDMKernelSignature with three differences:
///   - the third size is named `uncompressed_data_size`;
///   - the output is always `float*`, since there is no OutType parameter;
///   - a trailing `compressed_indices_table` (const std::int32_t*) is added.
///
/// NOTE(review): the table presumably maps uncompressed row ids to rows of a
/// pruned/compressed `input`; confirm in the implementation.
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t>
class EmbeddingSpMDMRowWiseSparseKernelSignature {
 public:
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t uncompressed_data_size,
      const InType* input,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      const float* weights,
      float* out,
      const std::int32_t* compressed_indices_table)>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename InType, |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t> |
|
|
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< |
|
|
InType, |
|
|
IndexType, |
|
|
OffsetType>::Type |
|
|
GenerateEmbeddingSpMDMRowWiseSparse( |
|
|
const std::int64_t block_size, |
|
|
bool has_weight, |
|
|
bool normalize_by_lengths, |
|
|
int prefetch = 16, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename IndexType, typename OffsetType = std::int32_t> |
|
|
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< |
|
|
std::uint8_t, |
|
|
IndexType, |
|
|
OffsetType>::Type |
|
|
GenerateEmbeddingSpMDMNBitRowWiseSparse( |
|
|
int bit_rate, |
|
|
const std::int64_t block_size, |
|
|
bool has_weight, |
|
|
bool normalize_by_lengths, |
|
|
int prefetch = 16, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Holder for the callable type produced by GenerateSparseAdaGrad.
///
/// `Type` returns int and takes, in order:
///   num_rows (int), param_size (std::uint64_t),
///   w (float*), g (const float*), h (float*),
///   indices (const IndexType*),
///   epsilon, lr, weight_decay (float),
///   counter (const double*), counter_halflife (std::int64_t).
///
/// NOTE(review): from the names, w/g/h are presumably weights, gradients and
/// the AdaGrad accumulator. The meaning of the int result is not visible
/// here; confirm in the implementation.
template <typename IndexType>
class SparseAdaGradSignature {
 public:
  using Type = std::function<int(
      int num_rows,
      std::uint64_t param_size,
      float* w,
      const float* g,
      float* h,
      const IndexType* indices,
      float epsilon,
      float lr,
      float weight_decay,
      const double* counter,
      std::int64_t counter_halflife)>;
};
|
|
|
|
|
template <typename IndexType> |
|
|
FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type |
|
|
GenerateSparseAdaGrad( |
|
|
int block_size, |
|
|
bool rowwise = false, |
|
|
int prefetch = 16, |
|
|
bool use_weight_decay = false); |
|
|
|
|
|
|
|
|
|
|
|
/// Holder for the callable type produced by GenerateRowWiseSparseAdaGradFused.
///
/// `Type` returns bool and takes three std::int64_t sizes (output_size,
/// index_size, data_size), then:
///   - w (DataType*), g (const float*), h (float*);
///   - indices and offsets_or_lengths;
///   - epsilon and lr.
/// The parameter type of `w` is set by DataType (default float).
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename DataType = float>
class RowWiseSparseAdaGradFusedSignature {
 public:
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t data_size,
      DataType* w,
      const float* g,
      float* h,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      float epsilon,
      float lr)>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template < |
|
|
typename IndexType, |
|
|
typename OffsetType = std::int32_t, |
|
|
typename DataType = float> |
|
|
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature< |
|
|
IndexType, |
|
|
OffsetType, |
|
|
DataType>::Type |
|
|
GenerateRowWiseSparseAdaGradFused( |
|
|
int block_size, |
|
|
int prefetch = 16, |
|
|
bool use_offsets = true, |
|
|
bool use_stochastic_rounding = true, |
|
|
int grad_stride = -1); |
|
|
|
|
|
namespace internal { |
|
|
|
|
|
template <typename InType, typename IndexType, typename OffsetType> |
|
|
FBGEMM_API bool EmbeddingSpMDMBlockSize1_( |
|
|
const std::int64_t output_size, |
|
|
const std::int64_t index_size, |
|
|
const std::int64_t data_size, |
|
|
const InType* input, |
|
|
const IndexType* indices, |
|
|
const OffsetType* offsets_or_lengths, |
|
|
const float* weights, |
|
|
bool normalize_by_lengths, |
|
|
float* out, |
|
|
bool is_weight_positional = false, |
|
|
bool use_offsets = true, |
|
|
bool is_bf16 = false); |
|
|
|
|
|
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) |
|
|
template <typename IndexType, bool HAS_WEIGHTS> |
|
|
void compressed_indices_remap_avx512( |
|
|
std::int32_t offsets_numel, |
|
|
const IndexType* indices, |
|
|
const int32_t* compressed_indices_mapping, |
|
|
const IndexType* offsets, |
|
|
const float* weights, |
|
|
IndexType* out_indices, |
|
|
IndexType* out_offsets, |
|
|
float* out_weights); |
|
|
#endif |
|
|
|
|
|
|
|
|
template < |
|
|
typename IndexType, |
|
|
typename OffsetType, |
|
|
typename OutType, |
|
|
bool NoBag, |
|
|
bool EnablePrefetching> |
|
|
FBGEMM_API bool EmbeddingSpMDM8Bit_Sve( |
|
|
const int64_t block_size, |
|
|
const int64_t output_size, |
|
|
const int64_t index_size, |
|
|
const int64_t data_size, |
|
|
const uint8_t* input, |
|
|
const IndexType* indices, |
|
|
const OffsetType* offsets_or_lengths, |
|
|
const float* weights, |
|
|
const bool normalize_by_lengths, |
|
|
OutType* out, |
|
|
const bool is_weight_positional, |
|
|
const bool use_offsets, |
|
|
const int64_t output_stride, |
|
|
const int64_t input_stride, |
|
|
const bool scale_bias_last, |
|
|
const bool is_bf16_out); |
|
|
|
|
|
} |
|
|
|
|
|
template <typename IndexType> |
|
|
FBGEMM_API void compressed_indices_remap( |
|
|
std::int32_t offsets_numel, |
|
|
const IndexType* indices, |
|
|
const int32_t* compressed_indices_mapping, |
|
|
const IndexType* offsets, |
|
|
const float* weights, |
|
|
IndexType* out_indices, |
|
|
IndexType* out_offsets, |
|
|
float* out_weights); |
|
|
|
|
|
} |
|
|
|