// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_
#include <vector>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
namespace litert::lm {

// Preprocesses the token ids before feeding them to the LLM.
// - token_ids: The token ids to be preprocessed. The vector is updated in
//   place (the start token is prepended to it).
// - start_token_id: The token id of the start token, which is prepended to
//   token_ids.
// - max_num_tokens: The maximum number of tokens the model can hold (i.e. the
//   length of the kv-cache). Setting it to 0 means no limit.
// - context_length_ratio_threhold: The maximum allowed ratio of the input
//   context length to max_num_tokens. If the ratio is larger than this
//   threshold, the function returns an error, indicating that there is not
//   enough space left for the model to generate the output. For example, if
//   max_num_tokens is 1024 and context_length_ratio_threhold is 0.9, an input
//   of 922 tokens is rejected because 922 > 1024 * 0.9 = 921.6.
// Returns an error status if the ratio check above fails.
//
// NOTE(review): whether the prepended start token is counted in the input
// length used for the ratio check is not visible from this header; confirm in
// the implementation.
absl::Status PreprocessTokenIds(std::vector<int>& token_ids, int start_token_id,
                                int max_num_tokens,
                                float context_length_ratio_threhold = 0.9f);

// Checks whether a stop token has been found in each decoded candidate
// sequence.
// - decoded_token_ids: The decoded token ids, one per candidate sequence. The
//   size of the span is the batch size of the decoded token ids generated by
//   the model.
// - stop_token_ids: The token ids that are treated as stop tokens.
// - stop_token_found: In/out vector recording, for each decoded candidate
//   sequence, whether a stop token has been found. The function updates it to
//   reflect whether a stop token is found in decoded_token_ids.
// Returns true when all values in stop_token_found are true (i.e. every
// candidate sequence has produced a stop token), and false otherwise.
//
// NOTE(review): the conditions under which an error status is returned are not
// documented here. Presumably this happens on a size mismatch between
// decoded_token_ids and stop_token_found; confirm against the implementation.
absl::StatusOr<bool> StopTokenFound(absl::Span<const int> decoded_token_ids,
                                    const std::vector<int>& stop_token_ids,
                                    std::vector<bool>& stop_token_found);

}  // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_