// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LLM_EXECUTOR_PROCESSED_TOKENS_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LLM_EXECUTOR_PROCESSED_TOKENS_H_
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
namespace litert::lm {
// Information which is used to process a token: the token id plus optional
// embedding vectors associated with it.
class TokenData {
 public:
  // Creates token data holding only the id; both embedding vectors start
  // empty.
  explicit TokenData(int token_id) : id_(token_id) {}

  // Creates token data holding the id together with the given embedding
  // vectors, which are moved in. Either vector may be empty.
  TokenData(int token_id, std::vector<float> token_embedding,
            std::vector<float> token_per_layer_embedding)
      : id_(token_id),
        embedding_(std::move(token_embedding)),
        per_layer_embedding_(std::move(token_per_layer_embedding)) {}

  // The token id that this data describes.
  int id() const { return id_; }

  // Read-only view of the token embedding; empty if none was provided.
  absl::Span<const float> embedding() const { return embedding_; }

  // Mutable access to the token embedding storage.
  std::vector<float>& mutable_embedding() { return embedding_; }

  // Read-only view of the per-layer embedding; empty if none was provided.
  absl::Span<const float> per_layer_embedding() const {
    return per_layer_embedding_;
  }

  // Mutable access to the per-layer embedding storage.
  std::vector<float>& mutable_per_layer_embedding() {
    return per_layer_embedding_;
  }

 private:
  // The token id that is to be processed. Being const, it fixes the identity
  // of this TokenData for its lifetime (and makes the class
  // non-copy-assignable; instances are shared via std::shared_ptr below).
  const int id_;
  // May contain the embedding corresponding to the token id.
  std::vector<float> embedding_;
  // May contain the per-layer embedding corresponding to the token id.
  std::vector<float> per_layer_embedding_;
};
// Keeps track of processed tokens during the LLM execution.
//
// This class is used by `ProcessedContext` to store the sequence of tokens
// that have been processed so far. It keeps track of both the processed tokens
// and a pending input token, if any, which may be used by backends which
// require an input token to be provided during Decode.
//
// During prefill, one set of processed tokens is maintained.
// During decode, output batch size (i.e. number of output candidates) sets of
// processed tokens are maintained.
class ProcessedTokens {
 public:
  // Tokens and their corresponding step. The number of tokens will be:
  // - Empty if the step does not correspond to the tokens in this
  //   ProcessedTokens.
  // - One if the step is for prefill.
  // - The size of the output batch if the step corresponds to decode.
  struct StepAndToken {
    int step;
    std::vector<std::shared_ptr<TokenData>> token;
  };

  ProcessedTokens();
  // Copyable and movable; moves are declared noexcept so that containers can
  // relocate instances without falling back to copies.
  ProcessedTokens(const ProcessedTokens&) = default;
  ProcessedTokens(ProcessedTokens&&) noexcept = default;
  ProcessedTokens& operator=(const ProcessedTokens&) = default;
  ProcessedTokens& operator=(ProcessedTokens&&) noexcept = default;

  // Returns the number of processed tokens inclusive of the pending input
  // token, if any.
  int TokenCount() const;

  // Reduces the token candidates to 1, keeping the candidate at the given
  // index. It will be called when the LLM switches from decode to prefill.
  absl::Status ReduceTokenCandidates(size_t index);

  // Broadcasts the token candidates to the given size.
  // It will be called when the LLM switches from prefill to decode.
  absl::Status BroadcastTokenCandidates(size_t size);

  // Returns `pending_input_token_` and its step, if it exists; otherwise,
  // the step after the last processed token.
  StepAndToken GetNextUnprocessedToken() const;

  // Appends the given tokens to the list of processed tokens.
  void AddProcessedTokens(const std::vector<int>& token_ids);

  // Adds a token (or tokens during Decode) as "pending" input token(s), which
  // indicates that the token has not yet been processed by the LLM, but is
  // part of the current context and is to be processed during the next Prefill
  // or Decode step. This may be used by backends which require an input token
  // to be provided during Decode.
  absl::Status AddPendingInputToken(
      const std::vector<std::shared_ptr<TokenData>>& token);

  // Reverts the processed tokens to the given step. This new step must be
  // non-negative and smaller than the current token count.
  absl::Status RollBackToStep(int new_step);

  // Returns the token at the given `step`, or empty if the step does not
  // correspond to a token. It may contain more than one token during decode
  // when the decode batch size is greater than one.
  std::vector<int> GetTokenAtStep(int step) const;

  // Marks the pending input token as processed.
  // Returns kNotFoundError if there is no pending input token.
  absl::Status MarkPendingInputTokenAsProcessed();

  // Returns a deep copy of the complete list of processed tokens, inclusive of
  // the pending input token, if any.
  std::vector<std::vector<int>> GetCopyOfTokens() const;

  // WARNING: This function returns a reference to the internal `tokens_`
  // directly, which may not include the pending input token. This method MUST
  // NOT be used in code that runs a backend which uses a pending input token.
  const std::vector<int>& GetTokensUnsafe() const;

  // Invalidates the pending input token, if any.
  void InvalidatePendingInputToken();

 private:
  // Current step of this ProcessedTokens (defined in the .cc file).
  int GetStep() const;
  // Whether a pending (not yet processed) input token exists.
  bool HasPendingInputToken() const;
  // Returns the pending input token(s) — presumably one entry per candidate
  // during decode; confirm against the .cc implementation.
  std::vector<std::shared_ptr<TokenData>> GetPendingInputToken() const;

  // One candidate sequence: its processed token ids plus, optionally, a
  // pending input token that is part of the context but not yet processed.
  struct Tokens {
    std::vector<int> token_ids;
    std::shared_ptr<TokenData> pending_input_token;
  };

  // tokens_.size() is 1 if prefill or output batch size if decode.
  std::vector<Tokens> tokens_;
};
} // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_LLM_EXECUTOR_PROCESSED_TOKENS_H_