#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>
#include <optional>
|
|
| #ifdef USE_FBGEMM |
| #include <fbgemm/FbgemmEmbedding.h> |
| #endif |
|
|
| namespace at::native { |
|
|
/// Reduction applied to the rows gathered for each bag.
///
/// The functions declared in this header take the mode as a plain `int64_t`,
/// so the enumerator values are spelled out explicitly: they are the integer
/// codes those `int64_t` arguments are compared against.
enum class EmbeddingBagMode {
  SUM = 0,
  MEAN = 1,
  MAX = 2,
};

/// `mode == EmbeddingBagMode::X` for an integer mode code.
/// constexpr/noexcept so the comparison is usable in constant expressions.
[[maybe_unused]] static constexpr bool operator==(int64_t op1, EmbeddingBagMode op2) noexcept {
  return op1 == static_cast<int64_t>(op2);
}

/// Reversed-operand form, so `EmbeddingBagMode::X == mode` also compiles
/// (C++17 does not synthesize reversed comparison candidates).
[[maybe_unused]] static constexpr bool operator==(EmbeddingBagMode op1, int64_t op2) noexcept {
  return static_cast<int64_t>(op1) == op2;
}

/// Negation of `mode == EmbeddingBagMode::X`.
[[maybe_unused]] static constexpr bool operator!=(int64_t op1, EmbeddingBagMode op2) noexcept {
  return !(op1 == op2);
}

/// Negation of `EmbeddingBagMode::X == mode`.
[[maybe_unused]] static constexpr bool operator!=(EmbeddingBagMode op1, int64_t op2) noexcept {
  return !(op1 == op2);
}
|
|
| void check_arguments( |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const int64_t mode, |
| const std::optional<Tensor>& per_sample_weights, |
| bool include_last_offset); |
|
|
| void make_bag_size_out( |
| Tensor& bag_size_out, |
| const Tensor& offsets, |
| const Tensor& indices, |
| const int64_t mode, |
| const bool include_last_offset, |
| const bool requires_grad); |
|
|
| void make_max_indices_out( |
| Tensor& max_indices_out, |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const Tensor& bag_size, |
| const int64_t mode, |
| bool include_last_offset); |
|
|
| void make_offset2bag_out( |
| Tensor& offset2bag, |
| Tensor& output, |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const int64_t mode, |
| const std::optional<Tensor>& per_sample_weights, |
| const int64_t padding_idx = -1); |
|
|
| #ifdef USE_FBGEMM |
|
|
| template<bool has_weight, typename TIndex, typename TData> |
| struct _CallbackAndBlockSize { |
| using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type; |
|
|
| int64_t blockSize = -1; |
| TCallback callback = nullptr; |
|
|
| static TCallback generateCallback(int64_t block_size) { |
| return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>( |
| block_size, |
| has_weight, |
| false, |
| 16, |
| false, |
| true); |
| } |
|
|
| _CallbackAndBlockSize() = default; |
|
|
| explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size) |
| : blockSize(maybe_block_size.value_or(-1)) |
| , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr) |
| {} |
| }; |
|
|
| template<typename... StorageMixins> |
| struct _EmbeddingBagKernelCacheImpl : private StorageMixins... { |
|
|
| _EmbeddingBagKernelCacheImpl() = default; |
| |
| explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size) |
| : StorageMixins(maybe_block_size)... |
| {} |
|
|
| |
| template<bool has_weight, typename TIndex, typename TData> |
| typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback |
| getCallback(int64_t block_size) const { |
| |
| |
| |
| if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) { |
| return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size); |
| } |
| |
| return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback; |
| } |
| }; |
|
|
| |
| |
// Concrete kernel cache used when FBGEMM is available.
//
// It has one _CallbackAndBlockSize slot per combination of:
//   - has_weight:  true / false
//   - index type:  int32_t / int64_t
//   - data type:   float / unsigned short
// getCallback<has_weight, TIndex, TData>() needs its combination to be one
// of the slots listed here.
//
// NOTE(review): `unsigned short` presumably carries 16-bit floating-point bit
// patterns (fp16/bf16); confirm against the callers that dispatch on it.
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
| #else |
/// Placeholder kernel cache for builds without FBGEMM.
///
/// It stores nothing. It mirrors the constructors of the FBGEMM-backed
/// `_EmbeddingBagKernelCacheImpl` (defaulted and `explicit` from an optional
/// block size), so code creating a cache compiles in both configurations.
struct _EmbeddingBagKernelCache {
  // The FBGEMM cache is default-constructible; without this, a plain
  // `_EmbeddingBagKernelCache cache;` only compiled when USE_FBGEMM was set.
  _EmbeddingBagKernelCache() = default;

  // The block size is irrelevant here because no kernels are generated.
  explicit _EmbeddingBagKernelCache(std::optional<int64_t> /*maybe_block_size*/) {}
};
| #endif |
|
|
| void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag, |
| Tensor& bag_size, Tensor* max_indices, |
| const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const int64_t mode = 0, |
| const std::optional<Tensor>& per_sample_weights = std::nullopt, |
| bool include_last_offset = false, |
| int64_t padding_idx = -1, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr); |
|
|
| void _embedding_bag_cpu_out( |
| at::Tensor& output, |
| at::Tensor& offset2bag, |
| at::Tensor& bag_size, |
| at::Tensor* p_max_indices, |
| const at::Tensor& weight, |
| const at::Tensor& indices, |
| const at::Tensor& offsets, |
| const bool scale_grad_by_freq, |
| const int64_t mode, |
| const bool sparse, |
| const std::optional<at::Tensor>& per_sample_weights, |
| const bool include_last_offset, |
| const std::optional<int64_t>& padding_idx, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr); |
|
|
| } |
|
|