File size: 7,164 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_EMBEDDING_LOOKUP_EMBEDDING_LOOKUP_TEXT_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_EMBEDDING_LOOKUP_EMBEDDING_LOOKUP_TEXT_H_

#include <sys/types.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/nullability.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/types/span.h"  // from @com_google_absl
#include "litert/cc/litert_compiled_model.h"  // from @litert
#include "litert/cc/litert_environment.h"  // from @litert
#include "litert/cc/litert_model.h"  // from @litert
#include "litert/cc/litert_options.h"  // from @litert
#include "litert/cc/litert_ranked_tensor_type.h"  // from @litert
#include "litert/cc/litert_tensor_buffer.h"  // from @litert
#include "runtime/components/embedding_lookup/embedding_lookup.h"

namespace litert::lm {

// Class used for looking up text embeddings on the CPU.
//
// Ideally text embedding lookups should be a part of the main model but there
// are cases where the embedding lookup needs to be done separately for now. For
// example, large embedding tables may use too much memory on the accelerator
// and so they need to be placed on the CPU. Currently there is no mechanism
// to tell a delegate to move embedding lookups to the CPU.
//
// Instances are obtained through Create(); the constructor is protected.
class EmbeddingLookupText : public EmbeddingLookup {
 public:
  ~EmbeddingLookupText() override = default;

  // Creates an EmbeddingLookupText instance. The reference of |model| is kept
  // in the returned instance, so the caller must ensure that |model| outlives
  // the returned instance.  If the model has more than one signature, and
  // signature_key is not provided, the first signature will be used by default.
  //
  // |env| may be null (it defaults to nullptr); how a null environment is
  // handled is decided by the implementation, which is not part of this header.
  static absl::StatusOr<std::unique_ptr<EmbeddingLookupText>> Create(
      const litert::Model* absl_nonnull model,
      std::optional<std::string> signature_key = std::nullopt,
      litert::Environment* absl_nullable env = nullptr);

  // For a given token, looks up the embedding and stores it in the
  // provided vector. The caller is responsible for ensuring that the vector is
  // the correct size for the embedding.
  //
  // This is used for the case where the llm_litert_executor needs to look up
  // embeddings for the current step and then use the result for the next step.
  // At that point, it does not have a LiteRtTensor to store the result in.
  absl::Status LookupDecode(int token,
                            std::vector<float>& decode_output_vector) override;

  // For a given token, looks up the embedding and stores it in the
  // output tensor.
  absl::Status LookupDecode(int token,
                            litert::TensorBuffer* decode_output) override;

  // For a given token, looks up the embedding and stores it in the
  // provided vector. The caller is responsible for ensuring that the vector is
  // the correct size for the embedding model output.
  //
  // This is used for the case where the llm_litert_executor needs to look up
  // embeddings for the current step and then use the result for the next step.
  // At that point, it does not have a LiteRtTensor to store the result in.
  absl::Status LookupPrefill(
      int token, std::vector<float>& prefill_output_vector) override;

  // For a given list of tokens, looks up the embeddings, concatenates them and
  // returns the result through the output tensor.
  //
  // Support is only partially implemented right now. This function only
  // supports the case where the output tensor's 0th dimension is of size
  // 1, its 1st dimension is greater than or equal to tokens.size(), and its
  // subsequent dimensions match the dimensions of the embedding model output.
  // In other words, if the embedding model outputs [B=1, T=1, N, H], then the
  // output tensor must be [1, >=tokens.size(), N, H].
  //
  // byte_offset is used to indicate what byte to start writing to in
  // prefill_output. This is used in cases where prefill_output has already
  // had some embeddings written to it.
  absl::Status LookupPrefill(absl::Span<const int> tokens,
                             litert::TensorBuffer* prefill_output,
                             size_t byte_offset) override;

  // Returns number of floats per token in the output tensor.
  size_t GetFloatsPerToken();

  // Returns the default embedding vector to use when a token is not found in
  // the lookup table. The returned reference is to a member of this object and
  // is only valid while this object is alive.
  const std::vector<float>& GetDefaultEmbeddingVector() const {
    return default_embedding_vector_;
  }

  // Returns the output buffer type for the embedding model (by value), or
  // std::nullopt if it has not been set.
  std::optional<litert::RankedTensorType> GetOutputBufferType() const {
    return output_buffer_type_;
  }

 protected:
  // Stores references to |env| and |model| (neither is owned) and takes
  // ownership of |signature_key|. |signature_key| is a by-value sink
  // parameter, so it is moved into the member rather than copied again.
  EmbeddingLookupText(litert::Environment& env,
                      const litert::Model* absl_nonnull model,
                      std::optional<std::string> signature_key)
      : env_(env), model_(*model), signature_key_(std::move(signature_key)) {}

  // Loads the provided model. This must be called before Lookup.
  absl::Status Initialize();

  // Internal implementation of Lookup for both the single and multiple token
  // cases.
  absl::Status LookupInternal(int token, absl::Span<uint8_t> buffer);

  // The environment for the embedding lookup.
  litert::Environment& env_;
  // The model for the embedding lookup. The actual model instance is owned by
  // the model resources.
  const litert::Model& model_;
  // The compiled model for the embedding model.
  std::optional<litert::CompiledModel> compiled_model_;

  // The input buffer for the embedding model.
  std::vector<litert::TensorBuffer> input_buffers_;

  // The output buffers for the embedding model.
  std::vector<litert::TensorBuffer> output_buffers_;
  // The output buffer type for the embedding model.
  std::optional<litert::RankedTensorType> output_buffer_type_;

  // The size of the output tensor needed for a single token. Zero-initialized
  // so it never holds an indeterminate value; the constructor does not set it.
  // NOTE(review): presumably populated by Initialize() -- confirm in the .cc.
  size_t floats_per_token_output_ = 0;

  // The default embedding vector to use when a token is not found in the
  // lookup table. This is set to the value of token id 0.
  std::vector<float> default_embedding_vector_;

  // The signature key to use for the embedding model. If not provided, the
  // first signature key will be used.
  std::optional<std::string> signature_key_;
};

}  // namespace litert::lm

#endif  // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_EMBEDDING_LOOKUP_EMBEDDING_LOOKUP_TEXT_H_