|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
#include <cstdint>
|
|
|
#include <functional>
|
|
|
|
|
|
#include "fbgemm/FbgemmBuild.h"
|
|
|
|
|
|
namespace fbgemm {
|
|
|
|
|
|
/**
 * Holder for the callable type of an embedding SpMDM kernel.
 *
 * The class has no data members; its only content is the nested alias
 * `Type`.
 *
 * @tparam InType     element type behind the `input` pointer
 * @tparam IndexType  element type behind the `indices` pointer
 * @tparam OffsetType element type behind the `offsets_or_lengths` pointer
 *                    (std::int32_t unless overridden)
 * @tparam OutType    element type behind the `out` pointer (float unless
 *                    overridden)
 */
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float>
class EmbeddingSpMDMKernelSignature {
 public:
  /**
   * Kernel entry point: a std::function that returns bool and receives
   * three 64-bit sizes followed by the input, indices, offsets/lengths,
   * weights and output pointers, in that order.
   *
   * NOTE(review): the meaning of the bool result and whether `weights` may
   * be null are not visible in this header; confirm against the kernel
   * implementation before relying on either.
   */
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t data_size,
      const InType* input,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      const float* weights,
      OutType* out)>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename InType,
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t,
|
|
|
typename OutType = float,
|
|
|
bool THREAD_LOCAL = false>
|
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
|
|
InType,
|
|
|
IndexType,
|
|
|
OffsetType,
|
|
|
OutType>::Type
|
|
|
GenerateEmbeddingSpMDM(
|
|
|
const std::int64_t block_size,
|
|
|
bool has_weight,
|
|
|
bool normalize_by_lengths,
|
|
|
int prefetch = 16,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true,
|
|
|
bool is_bf16_out = false,
|
|
|
bool is_bf16_in = false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename InType,
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t,
|
|
|
typename OutType = float,
|
|
|
bool THREAD_LOCAL = false>
|
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
|
|
InType,
|
|
|
IndexType,
|
|
|
OffsetType,
|
|
|
OutType>::Type
|
|
|
GenerateEmbeddingSpMDMWithStrides(
|
|
|
const std::int64_t block_size,
|
|
|
bool has_weight,
|
|
|
bool normalize_by_lengths,
|
|
|
int prefetch = 16,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true,
|
|
|
std::int64_t output_stride = -1,
|
|
|
std::int64_t input_stride = -1,
|
|
|
bool scale_bias_last = true,
|
|
|
bool no_bag = false,
|
|
|
bool is_bf16_out = false,
|
|
|
bool is_bf16_in = false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t,
|
|
|
typename OutType = float>
|
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
|
|
std::uint8_t,
|
|
|
IndexType,
|
|
|
OffsetType,
|
|
|
OutType>::Type
|
|
|
GenerateEmbeddingSpMDMNBit(
|
|
|
int bit_rate,
|
|
|
const std::int64_t block_size,
|
|
|
bool has_weight,
|
|
|
bool normalize_by_lengths,
|
|
|
int prefetch = 16,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t,
|
|
|
typename OutType = float,
|
|
|
bool THREAD_LOCAL = false>
|
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
|
|
std::uint8_t,
|
|
|
IndexType,
|
|
|
OffsetType,
|
|
|
OutType>::Type
|
|
|
GenerateEmbeddingSpMDMNBitWithStrides(
|
|
|
const int input_bit_rate,
|
|
|
const std::int64_t block_size,
|
|
|
bool has_weight,
|
|
|
bool normalize_by_lengths,
|
|
|
int prefetch = 16,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true,
|
|
|
std::int64_t output_stride = -1,
|
|
|
std::int64_t input_stride = -1,
|
|
|
bool scale_bias_last = true,
|
|
|
const bool is_bf16_out = false,
|
|
|
const bool no_bag = false,
|
|
|
int output_bit_rate = -1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t,
|
|
|
typename OutType = float>
|
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
|
|
std::uint8_t,
|
|
|
IndexType,
|
|
|
OffsetType,
|
|
|
OutType>::Type
|
|
|
GenerateEmbeddingSpMDMFP8WithStrides(
|
|
|
const std::int64_t block_size,
|
|
|
bool normalize_by_lengths,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true,
|
|
|
std::int64_t output_stride = -1,
|
|
|
std::int64_t input_stride = -1,
|
|
|
int exponent_bits = 4,
|
|
|
int exponent_bias = 7,
|
|
|
bool is_bf16_out = false);
|
|
|
|
|
|
/**
 * Holder for the callable type of a row-wise-sparse embedding SpMDM kernel.
 *
 * The class has no data members; its only content is the nested alias
 * `Type`. The output pointer is always `float*` here (there is no OutType
 * template parameter).
 *
 * @tparam InType     element type behind the `input` pointer
 * @tparam IndexType  element type behind the `indices` pointer
 * @tparam OffsetType element type behind the `offsets_or_lengths` pointer
 *                    (std::int32_t unless overridden)
 */
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t>
class EmbeddingSpMDMRowWiseSparseKernelSignature {
 public:
  /**
   * Kernel entry point: a std::function that returns bool and receives
   * three 64-bit sizes, the input, indices, offsets/lengths and weights
   * pointers, the float output pointer, and finally a pointer to a table
   * of std::int32_t.
   *
   * NOTE(review): `compressed_indices_table` presumably maps uncompressed
   * row indices to compressed ones; confirm against the kernel
   * implementation.
   */
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t uncompressed_data_size,
      const InType* input,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      const float* weights,
      float* out,
      const std::int32_t* compressed_indices_table)>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename InType,
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t>
|
|
|
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
|
|
|
InType,
|
|
|
IndexType,
|
|
|
OffsetType>::Type
|
|
|
GenerateEmbeddingSpMDMRowWiseSparse(
|
|
|
const std::int64_t block_size,
|
|
|
bool has_weight,
|
|
|
bool normalize_by_lengths,
|
|
|
int prefetch = 16,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename IndexType, typename OffsetType = std::int32_t>
|
|
|
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
|
|
|
std::uint8_t,
|
|
|
IndexType,
|
|
|
OffsetType>::Type
|
|
|
GenerateEmbeddingSpMDMNBitRowWiseSparse(
|
|
|
int bit_rate,
|
|
|
const std::int64_t block_size,
|
|
|
bool has_weight,
|
|
|
bool normalize_by_lengths,
|
|
|
int prefetch = 16,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Holder for the callable type of a sparse AdaGrad kernel.
 *
 * The class has no data members; its only content is the nested alias
 * `Type`.
 *
 * @tparam IndexType element type behind the `indices` pointer
 */
template <typename IndexType>
class SparseAdaGradSignature {
 public:
  /**
   * Kernel entry point: a std::function that returns int.
   *
   * `w` and `h` are mutable float buffers while `g`, `indices` and
   * `counter` are read-only.
   *
   * NOTE(review): the meaning of the int result, and whether `counter` may
   * be null when weight decay is not in use, are not visible in this
   * header; confirm against the kernel implementation.
   */
  using Type = std::function<int(
      int num_rows,
      std::uint64_t param_size,
      float* w,
      const float* g,
      float* h,
      const IndexType* indices,
      float epsilon,
      float lr,
      float weight_decay,
      const double* counter,
      std::int64_t counter_halflife)>;
};
|
|
|
|
|
|
template <typename IndexType>
|
|
|
FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type
|
|
|
GenerateSparseAdaGrad(
|
|
|
int block_size,
|
|
|
bool rowwise = false,
|
|
|
int prefetch = 16,
|
|
|
bool use_weight_decay = false);
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Holder for the callable type of a fused row-wise sparse AdaGrad kernel.
 *
 * The class has no data members; its only content is the nested alias
 * `Type`.
 *
 * @tparam IndexType  element type behind the `indices` pointer
 * @tparam OffsetType element type behind the `offsets_or_lengths` pointer
 *                    (std::int32_t unless overridden)
 * @tparam DataType   element type behind the mutable `w` pointer (float
 *                    unless overridden)
 */
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename DataType = float>
class RowWiseSparseAdaGradFusedSignature {
 public:
  /**
   * Kernel entry point: a std::function that returns bool.
   *
   * `w` and `h` are mutable buffers while `g`, `indices` and
   * `offsets_or_lengths` are read-only.
   *
   * NOTE(review): the meaning of the bool result is not visible in this
   * header; confirm against the kernel implementation.
   */
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t data_size,
      DataType* w,
      const float* g,
      float* h,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      float epsilon,
      float lr)>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <
|
|
|
typename IndexType,
|
|
|
typename OffsetType = std::int32_t,
|
|
|
typename DataType = float>
|
|
|
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<
|
|
|
IndexType,
|
|
|
OffsetType,
|
|
|
DataType>::Type
|
|
|
GenerateRowWiseSparseAdaGradFused(
|
|
|
int block_size,
|
|
|
int prefetch = 16,
|
|
|
bool use_offsets = true,
|
|
|
bool use_stochastic_rounding = true,
|
|
|
int grad_stride = -1);
|
|
|
|
|
|
namespace internal {
|
|
|
|
|
|
template <typename InType, typename IndexType, typename OffsetType>
|
|
|
FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
|
|
|
const std::int64_t output_size,
|
|
|
const std::int64_t index_size,
|
|
|
const std::int64_t data_size,
|
|
|
const InType* input,
|
|
|
const IndexType* indices,
|
|
|
const OffsetType* offsets_or_lengths,
|
|
|
const float* weights,
|
|
|
bool normalize_by_lengths,
|
|
|
float* out,
|
|
|
bool is_weight_positional = false,
|
|
|
bool use_offsets = true,
|
|
|
bool is_bf16 = false);
|
|
|
|
|
|
/**
 * Declaration of the internal index-remapping routine; the definition is
 * elsewhere. It returns nothing: results are delivered through the three
 * mutable `out_*` pointers.
 *
 * The mapping pointer is spelled `std::int32_t` to match the rest of this
 * header; <cstdint> only guarantees the std-qualified name. The type, and
 * therefore the signature, is unchanged.
 *
 * @tparam IndexType    element type of `indices`, `offsets`, `out_indices`
 *                      and `out_offsets`
 * @tparam HAS_WEIGHTS  compile-time flag. NOTE(review): presumably selects
 *                      whether `weights` / `out_weights` are touched;
 *                      confirm against the definition.
 *
 * @param offsets_numel               32-bit element count
 * @param indices                     read-only indices
 * @param compressed_indices_mapping  read-only table of std::int32_t
 * @param offsets                     read-only offsets
 * @param weights                     read-only weights
 * @param out_indices                 destination for indices
 * @param out_offsets                 destination for offsets
 * @param out_weights                 destination for weights
 */
template <typename IndexType, bool HAS_WEIGHTS>
void compressed_indices_remap_avx512(
    std::int32_t offsets_numel,
    const IndexType* indices,
    const std::int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights,
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights);
|
|
|
|
|
|
}
|
|
|
|
|
|
template <typename IndexType>
|
|
|
FBGEMM_API void compressed_indices_remap(
|
|
|
std::int32_t offsets_numel,
|
|
|
const IndexType* indices,
|
|
|
const int32_t* compressed_indices_mapping,
|
|
|
const IndexType* offsets,
|
|
|
const float* weights,
|
|
|
IndexType* out_indices,
|
|
|
IndexType* out_offsets,
|
|
|
float* out_weights);
|
|
|
|
|
|
}
|
|
|
|