// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/core/session_utils.h" #include // NOLINT: Required for path manipulation. #include #include #include #include #include #include #include #include #include "absl/container/flat_hash_map.h" // from @com_google_absl #include "absl/memory/memory.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_join.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "litert/cc/litert_tensor_buffer.h" // from @litert #include "litert/test/matchers.h" // from @litert #include "runtime/components/sentencepiece_tokenizer.h" #include "runtime/components/tokenizer.h" #include "runtime/engine/engine_settings.h" #include "runtime/engine/io_types.h" #include "runtime/proto/sampler_params.pb.h" #include "runtime/util/convert_tensor_buffer.h" #include "runtime/util/status_macros.h" #include "runtime/util/test_utils.h" // IWYU pragma: keep namespace litert::lm { namespace { constexpr absl::string_view kTestdataDir = "litert_lm/runtime/components/testdata/"; class ExtendedTokenizer : public Tokenizer { public: static absl::StatusOr> CreateFromFile( absl::string_view model_path) { ASSIGN_OR_RETURN(auto tokenizer, SentencePieceTokenizer::CreateFromFile(model_path)); return absl::WrapUnique(new ExtendedTokenizer(std::move(tokenizer))); } void SetExtendedToken(int token_id, absl::string_view token_str) { extended_tokens_to_id_[token_str] = token_id; id_to_extended_tokens_[token_id] = token_str; } absl::StatusOr> TextToTokenIds( absl::string_view text) override { std::vector token_ids; bool is_extended_token_found = false; do { is_extended_token_found = false; for (const auto& [extended_token_str, extended_token_id] : extended_tokens_to_id_) { auto extended_token_pos = text.find(extended_token_str); if (extended_token_pos != std::string::npos) { // The text before the extended token. ASSIGN_OR_RETURN( auto text_ids, tokenizer_->TextToTokenIds(text.substr(0, extended_token_pos))); token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); token_ids.push_back(extended_token_id); text = text.substr(extended_token_pos + extended_token_str.size()); is_extended_token_found = true; } } } while (is_extended_token_found); if (!text.empty()) { ASSIGN_OR_RETURN(auto text_ids, tokenizer_->TextToTokenIds(text)); token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); } return token_ids; } absl::StatusOr TokenIdsToText( const std::vector& token_ids) override { std::vector token_strs; for (int token_id : token_ids) { if (id_to_extended_tokens_.contains(token_id)) { token_strs.push_back(id_to_extended_tokens_[token_id]); } else { token_strs.push_back(tokenizer_->TokenIdsToText({token_id}).value()); } } return absl::StrJoin(token_strs, ""); } absl::StatusOr TokenToId(absl::string_view token) override { if (extended_tokens_to_id_.contains(token)) { return extended_tokens_to_id_[token]; } return tokenizer_->TokenToId(token); } TokenizerType GetTokenizerType() const override { return tokenizer_->GetTokenizerType(); } std::vector GetTokens() const override { return tokenizer_->GetTokens(); } private: explicit ExtendedTokenizer(std::unique_ptr tokenizer) : tokenizer_(std::move(tokenizer)) {}; absl::flat_hash_map id_to_extended_tokens_; absl::flat_hash_map extended_tokens_to_id_; std::unique_ptr tokenizer_; }; class SessionUtilsTest : public testing::Test { protected: void SetUp() override { auto tokenizer = ExtendedTokenizer::CreateFromFile( (std::filesystem::path(::testing::SrcDir()) / std::string(kTestdataDir) / "sentencepiece.model") .string()); ASSERT_OK(tokenizer); tokenizer.value()->SetExtendedToken(256000, " Букмекерлер"); tokenizer_ = std::move(*tokenizer); } std::unique_ptr tokenizer_; }; TEST_F(SessionUtilsTest, MaybeGetBosString) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); // Corresponds to "" ASSERT_OK_AND_ASSIGN(auto bos_string, MaybeGetBosString(session_config, *tokenizer_)); EXPECT_EQ(bos_string, ""); } TEST_F(SessionUtilsTest, StringToProcessedInputText) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); // Corresponds to "" std::optional benchmark_info; ASSERT_OK_AND_ASSIGN(auto input_text, StringToProcessedInputText( "Hello World!", session_config, *tokenizer_, benchmark_info)); ASSERT_TRUE(input_text.IsTensorBuffer()); ASSERT_OK_AND_ASSIGN(auto text_tensor, input_text.GetPreprocessedTextTensor()); ASSERT_NE(text_tensor, nullptr); LITERT_ASSERT_OK_AND_ASSIGN(auto token_ids_span, ReferTensorBufferAsSpan(*text_tensor)); EXPECT_THAT(std::vector(token_ids_span.begin(), token_ids_span.end()), testing::ElementsAre(2, 90, 547, 58, 735, 210, 466, 2294)); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesFails) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); // Corresponds to "" std::vector inputs_with_bos; inputs_with_bos.emplace_back(InputText("Hello World!")); EXPECT_THAT( ApplyPromptTemplates(inputs_with_bos, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true), testing::status::StatusIs(absl::StatusCode::kInvalidArgument, "Input contains bos control token. Control " "token should not be included in the input.")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesCanHandleEmptyContent) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); // Corresponds to "" { std::vector empty_inputs; ASSERT_OK_AND_ASSIGN( auto templated_single, ApplyPromptTemplates(empty_inputs, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_single.size(), 1); EXPECT_THAT(std::get(templated_single[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); } for (const auto& content_type : {ContentType::kFirst, ContentType::kLast, ContentType::kMiddle}) { std::vector empty_inputs; ASSERT_OK_AND_ASSIGN( auto templated_empty, ApplyPromptTemplates(empty_inputs, content_type, session_config, *tokenizer_, /*is_first_turn=*/false)); EXPECT_TRUE(templated_empty.empty()); } } TEST_F(SessionUtilsTest, ApplyPromptTemplatesWithSingleTextChunk) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); { std::vector single_chunk; single_chunk.emplace_back(InputText("Hello ")); ASSERT_OK_AND_ASSIGN( auto templated_single, ApplyPromptTemplates(single_chunk, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_single.size(), 2); EXPECT_THAT(std::get(templated_single[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_single[1]).GetRawTextString(), testing::status::IsOkAndHolds("User\nHello ")); } { std::vector single_chunk; single_chunk.emplace_back(InputText("world!")); ASSERT_OK_AND_ASSIGN( auto templated_single, ApplyPromptTemplates(single_chunk, ContentType::kMiddle, session_config, *tokenizer_, /*is_first_turn=*/false)); ASSERT_EQ(templated_single.size(), 1); EXPECT_THAT(std::get(templated_single[0]).GetRawTextString(), testing::status::IsOkAndHolds("world!")); } { std::vector single_chunk; single_chunk.emplace_back(InputText("")); ASSERT_OK_AND_ASSIGN( auto templated_single, ApplyPromptTemplates(single_chunk, ContentType::kLast, session_config, *tokenizer_, /*is_first_turn=*/false)); ASSERT_EQ(templated_single.size(), 1); EXPECT_THAT(std::get(templated_single[0]).GetRawTextString(), testing::status::IsOkAndHolds("\nModel\n")); } } TEST_F(SessionUtilsTest, ApplyPromptTemplatesDisabled) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); session_config.SetApplyPromptTemplateInSession(false); // Single text chunk. (is_first_chunk=true, is_last_chunk=true) std::vector single_chunk; single_chunk.emplace_back(InputText("Hello World!")); ASSERT_OK_AND_ASSIGN( auto templated_single, ApplyPromptTemplates(single_chunk, ContentType::kNA, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_single.size(), 2); EXPECT_THAT(std::get(templated_single[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_single[1]).GetRawTextString(), testing::status::IsOkAndHolds("Hello World!")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesWithTwoTextChunks) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); std::vector two_chunks; two_chunks.emplace_back(InputText("First")); two_chunks.emplace_back(InputText("Second")); ASSERT_OK_AND_ASSIGN( auto templated_two, ApplyPromptTemplates(two_chunks, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_two.size(), 3); EXPECT_THAT(std::get(templated_two[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_two[1]).GetRawTextString(), testing::status::IsOkAndHolds("User\nFirst")); EXPECT_THAT(std::get(templated_two[2]).GetRawTextString(), testing::status::IsOkAndHolds("Second")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesDisabledWithTwoTextChunks) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); session_config.SetApplyPromptTemplateInSession(false); // Two text chunks. (First chunk: is_first=true, is_last=false; // Second chunk: is_first=false, is_last=true) std::vector two_chunks; two_chunks.emplace_back(InputText("First")); two_chunks.emplace_back(InputText("Second")); ASSERT_OK_AND_ASSIGN( auto templated_two, ApplyPromptTemplates(two_chunks, ContentType::kNA, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_two.size(), 3); EXPECT_THAT(std::get(templated_two[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_two[1]).GetRawTextString(), testing::status::IsOkAndHolds("First")); EXPECT_THAT(std::get(templated_two[2]).GetRawTextString(), testing::status::IsOkAndHolds("Second")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesWithThreeTextChunks) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); // Three text chunks. (Middle chunk: is_first=false, is_last=false) std::vector three_chunks; three_chunks.emplace_back(InputText("First")); three_chunks.emplace_back(InputText("Middle")); three_chunks.emplace_back(InputText("Last")); ASSERT_OK_AND_ASSIGN(auto templated_three, ApplyPromptTemplates(three_chunks, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_three.size(), 4); EXPECT_THAT(std::get(templated_three[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_three[1]).GetRawTextString(), testing::status::IsOkAndHolds("User\nFirst")); EXPECT_THAT(std::get(templated_three[2]).GetRawTextString(), testing::status::IsOkAndHolds("Middle")); EXPECT_THAT(std::get(templated_three[3]).GetRawTextString(), testing::status::IsOkAndHolds("Last")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesWithMixedChunksTextAndImage) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); // Mixed chunks - text and image. Non-text inputs are passed through. std::vector mixed_chunks; mixed_chunks.emplace_back(InputText("Text1")); mixed_chunks.emplace_back(InputImage("123")); mixed_chunks.emplace_back(InputText("Text2")); ASSERT_OK_AND_ASSIGN( auto templated_mixed, ApplyPromptTemplates(mixed_chunks, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_mixed.size(), 4); EXPECT_THAT(std::get(templated_mixed[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_mixed[1]).GetRawTextString(), testing::status::IsOkAndHolds("User\nText1")); EXPECT_TRUE(std::holds_alternative(templated_mixed[2])); EXPECT_THAT(std::get(templated_mixed[3]).GetRawTextString(), testing::status::IsOkAndHolds("Text2")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesWithSubsequentTurn) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); std::vector single_chunk_again; single_chunk_again.emplace_back(InputText("Another turn")); ASSERT_OK_AND_ASSIGN( auto templated_first_turn, ApplyPromptTemplates(single_chunk_again, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_first_turn.size(), 2); EXPECT_THAT(std::get(templated_first_turn[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_first_turn[1]).GetRawTextString(), testing::status::IsOkAndHolds("User\nAnother turn")); ASSERT_OK_AND_ASSIGN( auto templated_again, ApplyPromptTemplates(single_chunk_again, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/false)); ASSERT_EQ(templated_again.size(), 1); EXPECT_THAT(std::get(templated_again[0]).GetRawTextString(), testing::status::IsOkAndHolds("User\nAnother turn")); } TEST_F(SessionUtilsTest, ApplyPromptTemplatesWithSingleImageInput) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); // Single image input. Templates are applied to the first and // last chunks. In this case, the image input is both the first and last // chunks, and the text chunks (templates) will be added before and after // the image. std::vector single_image; single_image.emplace_back(InputImage("456")); ASSERT_OK_AND_ASSIGN( auto templated_image, ApplyPromptTemplates(single_image, ContentType::kFirst, session_config, *tokenizer_, /*is_first_turn=*/true)); ASSERT_EQ(templated_image.size(), 3); EXPECT_THAT(std::get(templated_image[0]).GetRawTextString(), testing::status::IsOkAndHolds("")); EXPECT_THAT(std::get(templated_image[1]).GetRawTextString(), testing::status::IsOkAndHolds("User\n")); EXPECT_TRUE(std::holds_alternative(templated_image[2])); } TEST_F(SessionUtilsTest, PreprocessContents) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); std::vector contents; contents.emplace_back(InputText("Hello World!")); std::optional benchmark_info; ASSERT_OK_AND_ASSIGN(auto preprocessed_contents, PreprocessContents(contents, session_config, *tokenizer_, benchmark_info)); ASSERT_EQ(preprocessed_contents.size(), 1); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[0])); const auto& text_data = std::get(preprocessed_contents[0]); ASSERT_TRUE(text_data.IsTensorBuffer()); ASSERT_OK_AND_ASSIGN(auto text_tensor, text_data.GetPreprocessedTextTensor()); ASSERT_NE(text_tensor, nullptr); LITERT_ASSERT_OK_AND_ASSIGN(auto token_ids_span, ReferTensorBufferAsSpan(*text_tensor)); EXPECT_THAT(std::vector(token_ids_span.begin(), token_ids_span.end()), testing::ElementsAre(2, 90, 547, 58, 735, 210, 466, 2294)); } TEST_F(SessionUtilsTest, PreprocessContentsMultimodal) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetStartTokenId(2); std::vector contents; contents.emplace_back(InputText("Hello World!")); std::vector dummy_image_data = {0.1f, 0.2f, 0.3f}; LITERT_ASSERT_OK_AND_ASSIGN( auto image_tensor, CopyToTensorBuffer(dummy_image_data, {1, 1, 1, 3})); contents.emplace_back(InputImage(std::move(image_tensor))); contents.emplace_back(InputImageEnd()); absl::flat_hash_map tensor_map; std::vector map_data = {0.7f, 0.8f}; LITERT_ASSERT_OK_AND_ASSIGN(auto map_tensor, CopyToTensorBuffer(map_data, {1, 2})); tensor_map["key1"] = std::move(map_tensor); contents.emplace_back(InputImage(std::move(tensor_map))); contents.emplace_back(InputImageEnd()); std::vector dummy_audio_data = {0.4f, 0.5f, 0.6f}; LITERT_ASSERT_OK_AND_ASSIGN( auto audio_tensor, CopyToTensorBuffer(dummy_audio_data, {1, 3, 1})); contents.emplace_back(InputAudio(std::move(audio_tensor))); contents.emplace_back(InputAudioEnd()); std::optional benchmark_info; ASSERT_OK_AND_ASSIGN(auto preprocessed_contents, PreprocessContents(contents, session_config, *tokenizer_, benchmark_info)); ASSERT_EQ(preprocessed_contents.size(), 7); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[0])); const auto& text_data = std::get(preprocessed_contents[0]); ASSERT_TRUE(text_data.IsTensorBuffer()); ASSERT_OK_AND_ASSIGN(auto text_tensor, text_data.GetPreprocessedTextTensor()); ASSERT_NE(text_tensor, nullptr); LITERT_ASSERT_OK_AND_ASSIGN(auto token_ids_span, ReferTensorBufferAsSpan(*text_tensor)); EXPECT_THAT(std::vector(token_ids_span.begin(), token_ids_span.end()), testing::ElementsAre(2, 90, 547, 58, 735, 210, 466, 2294)); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[1])); const auto& image_data = std::get(preprocessed_contents[1]); ASSERT_TRUE(image_data.IsTensorBuffer()); ASSERT_OK_AND_ASSIGN(auto img_tensor_out, image_data.GetPreprocessedImageTensor()); LITERT_ASSERT_OK_AND_ASSIGN(auto img_span, ReferTensorBufferAsSpan(*img_tensor_out)); EXPECT_THAT(std::vector(img_span.begin(), img_span.end()), testing::ElementsAre(0.1f, 0.2f, 0.3f)); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[2])); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[3])); const auto& image_map_data = std::get(preprocessed_contents[3]); ASSERT_TRUE(image_map_data.IsTensorBufferMap()); ASSERT_OK_AND_ASSIGN(auto img_map_out, image_map_data.GetPreprocessedImageTensorMap()); ASSERT_TRUE(img_map_out->contains("key1")); LITERT_ASSERT_OK_AND_ASSIGN( auto map_span, ReferTensorBufferAsSpan(img_map_out->at("key1"))); EXPECT_THAT(std::vector(map_span.begin(), map_span.end()), testing::ElementsAre(0.7f, 0.8f)); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[4])); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[5])); const auto& audio_data = std::get(preprocessed_contents[5]); ASSERT_TRUE(audio_data.IsTensorBuffer()); ASSERT_OK_AND_ASSIGN(auto audio_tensor_out, audio_data.GetPreprocessedAudioTensor()); LITERT_ASSERT_OK_AND_ASSIGN( auto audio_span, ReferTensorBufferAsSpan(*audio_tensor_out)); EXPECT_THAT(std::vector(audio_span.begin(), audio_span.end()), testing::ElementsAre(0.4f, 0.5f, 0.6f)); ASSERT_TRUE(std::holds_alternative(preprocessed_contents[6])); } TEST_F(SessionUtilsTest, PreprocessContentsWithEmptyInputText) { SessionConfig session_config = SessionConfig::CreateDefault(); std::vector contents; contents.emplace_back(InputText("")); ASSERT_OK_AND_ASSIGN(auto preprocessed_contents, PreprocessContents(contents, session_config, *tokenizer_, /*benchmark_info=*/std::nullopt)); EXPECT_TRUE(preprocessed_contents.empty()); } } // namespace } // namespace litert::lm