File size: 914 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

// Abstract interface for a packed-embedding-parameters object.
//
// Every member function declared here is pure virtual, so this type cannot be
// instantiated directly; a concrete subclass must implement all of them.
// The struct derives from torch::jit::CustomClassHolder.
// NOTE(review): presumably this is what lets instances be registered and
// passed around as TorchScript custom-class objects — confirm against the
// registration code.
struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
  // Embedding-bag lookup returning an at::Tensor.
  // NOTE(review): judging by the name, this is presumably the variant for
  // byte (8-bit) quantized weights — confirm against the implementing class.
  //
  // Parameters (meanings below are inferred from the names only; the actual
  // contract is defined by the subclass — TODO confirm):
  //   indices                    - required index tensor.
  //   offsets                    - optional tensor; callers may pass nullopt.
  //   pruned_weights             - flag; presumably signals that the packed
  //                                weights were pruned.
  //   per_sample_weights_        - optional tensor; callers may pass nullopt.
  //   compressed_indices_mapping - optional tensor; presumably only relevant
  //                                together with pruned_weights.
  //   include_last_offset        - flag; presumably mirrors the option of the
  //                                same name on the embedding_bag operator.
  //   is_embedding_op            - flag; presumably distinguishes a plain
  //                                embedding call from an embedding-bag call.
  virtual at::Tensor embeddingbag_byte(
    const at::Tensor& indices,
    const c10::optional<at::Tensor>& offsets,
    bool pruned_weights,
    const c10::optional<at::Tensor>& per_sample_weights_,
    const c10::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) = 0;

  // Embedding-bag lookup returning an at::Tensor, with exactly the same
  // parameter list as embeddingbag_byte above.
  // NOTE(review): judging by the name, this is presumably the variant for
  // 4-bit quantized weights — confirm against the implementing class.
  virtual at::Tensor embeddingbag_4bit(
    const at::Tensor& indices,
    const c10::optional<at::Tensor>& offsets,
    bool pruned_weights,
    const c10::optional<at::Tensor>& per_sample_weights_,
    const c10::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) = 0;

  // Returns an at::Tensor; takes no arguments and is not const-qualified.
  // NOTE(review): presumably yields the weight tensor that was packed into
  // this object — confirm against the implementing class.
  virtual at::Tensor unpack() = 0;

  // Const accessors returning int64_t.
  // NOTE(review): bit_rate() presumably reports the number of bits per
  // quantized weight element, and version() presumably reports a packing or
  // serialization format version — confirm both against the implementation.
  virtual int64_t bit_rate() const = 0;
  virtual int64_t version() const = 0;
};