// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_ #include #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl namespace litert::lm { // Preprocesses the token ids before feeding them to the LLM. // - token_ids: The token ids to be preprocessed. // - start_token_id: The token id of the start token. It is used to be prepended // to the token ids. // - max_num_tokens: The maximum number of tokens the model can hold (i.e. the // length of the kv-cache). Setting it to be 0 means no limit. // - context_length_ratio_threhold: The threshold of the ratio of input context // length to the max_num_tokens. If the ratio is larger than the threshold, // the function will return an error, indicating that there is not enough // space for the model to generate the output. For example, if the max_num // tokens is 1024 and the context_length_ratio_threhold is 0.9, the function // will return an error if the input token length is 922 (922 > 1024 * 0.9). absl::Status PreprocessTokenIds(std::vector& token_ids, int start_token_id, int max_num_tokens, float context_length_ratio_threhold = 0.9f); // Checks if the stop token is found in the decoded token ids. // - decoded_token_ids: The decoded token ids. The size of the vector span is // the batch size of the decoded token ids generated by the model. // - stop_token_id: The token id of the stop token. // - stop_token_found: The vector to store the information of whether a stop // token is found in each of the decoded candidate sequences. The function is // responsible for updating the vector to reflect whether a stop token is // found in the decoded token ids. The function returns true when all values // in the vector are true. absl::StatusOr StopTokenFound(absl::Span decoded_token_ids, const std::vector& stop_token_ids, std::vector& stop_token_found); } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_