// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/core/session_advanced.h" #include #include // NOLINT: Required for path manipulation. #include #include #include #include #include #include #include #include #include "absl/container/flat_hash_map.h" // from @com_google_absl #include "absl/functional/any_invocable.h" // from @com_google_absl #include "absl/memory/memory.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_join.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/time/clock.h" // from @com_google_absl #include "absl/time/time.h" // from @com_google_absl #include "litert/cc/litert_environment.h" // from @litert #include "litert/cc/litert_tensor_buffer.h" // from @litert #include "litert/test/matchers.h" // from @litert #include "runtime/components/constrained_decoding/fake_constraint.h" #include "runtime/components/model_resources.h" #include "runtime/components/sentencepiece_tokenizer.h" #include "runtime/components/tokenizer.h" #include "runtime/engine/engine_settings.h" #include "runtime/engine/io_types.h" #include "runtime/executor/audio_executor_settings.h" #include "runtime/executor/executor_settings_base.h" #include "runtime/executor/fake_llm_executor.h" #include "runtime/framework/resource_management/execution_manager.h" #include "runtime/util/convert_tensor_buffer.h" #include "runtime/util/scoped_file.h" #include "runtime/util/status_macros.h" #include "runtime/util/test_utils.h" // IWYU pragma: keep namespace litert::lm { namespace { using ::testing::status::StatusIs; constexpr absl::string_view kTestdataDir = "litert_lm/runtime/components/testdata/"; constexpr absl::string_view kTestAudioModelPath = "litert_lm/runtime/testdata/dummy_audio_only.litertlm"; constexpr int kSpectrogramFrequencySlots = 8; constexpr int kSpectrogramSequenceLength = 10; constexpr int kEmbeddingSequenceLength = 5; constexpr int kEmbeddingDimensions = 6; // Audio embedding tensor will have shape [1, kEmbeddingSequenceLength, // kEmbeddingDimensions]. constexpr std::array kExpectedAudioEmbedding = {0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 3., 3., 0., 1., 2., 4., 4., 4., 1., 2., 3., 5., 5., 5., 0., 1., 2., 4., 4., 4.}; // Mel spectrogram tensor will have shape [1, kSpectrogramSequenceLength, // kSpectrogramFrequencySlots]. constexpr std::array mel_spectrogram_data = { 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1.}; absl::StatusOr> CreateFakeLlmExecutor( std::vector> prefill_tokens, std::vector> decode_tokens, std::optional> audio_embedding = std::nullopt) { auto batch_size = decode_tokens.empty() ? 
1 : decode_tokens[0].size(); auto fake_executor = std::make_unique( 2560, prefill_tokens, decode_tokens, batch_size, audio_embedding); return std::move(fake_executor); } class ExtendedTokenizer : public Tokenizer { public: static absl::StatusOr> CreateFromFile( absl::string_view model_path) { ASSIGN_OR_RETURN(auto tokenizer, SentencePieceTokenizer::CreateFromFile(model_path)); return absl::WrapUnique(new ExtendedTokenizer(std::move(tokenizer))); } void SetExtendedToken(int token_id, absl::string_view token_str) { extended_tokens_to_id_[token_str] = token_id; id_to_extended_tokens_[token_id] = token_str; } absl::StatusOr> TextToTokenIds( absl::string_view text) override { std::vector token_ids; bool is_extended_token_found = false; do { is_extended_token_found = false; for (const auto& [extended_token_str, extended_token_id] : extended_tokens_to_id_) { auto extended_token_pos = text.find(extended_token_str); if (extended_token_pos != std::string::npos) { // The text before the extended token. ASSIGN_OR_RETURN( auto text_ids, tokenizer_->TextToTokenIds(text.substr(0, extended_token_pos))); token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); token_ids.push_back(extended_token_id); text = text.substr(extended_token_pos + extended_token_str.size()); is_extended_token_found = true; } } } while (is_extended_token_found); if (!text.empty()) { ASSIGN_OR_RETURN(auto text_ids, tokenizer_->TextToTokenIds(text)); token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); } return token_ids; } absl::StatusOr TokenIdsToText( const std::vector& token_ids) override { std::vector token_strs; for (int token_id : token_ids) { if (id_to_extended_tokens_.contains(token_id)) { token_strs.push_back(id_to_extended_tokens_[token_id]); } else { token_strs.push_back(tokenizer_->TokenIdsToText({token_id}).value()); } } return absl::StrJoin(token_strs, ""); } absl::StatusOr TokenToId(absl::string_view token) override { if (extended_tokens_to_id_.contains(token)) { return extended_tokens_to_id_[token]; } return tokenizer_->TokenToId(token); } TokenizerType GetTokenizerType() const override { return tokenizer_->GetTokenizerType(); } std::vector GetTokens() const override { return tokenizer_->GetTokens(); } private: explicit ExtendedTokenizer(std::unique_ptr tokenizer) : tokenizer_(std::move(tokenizer)) {}; absl::flat_hash_map id_to_extended_tokens_; absl::flat_hash_map extended_tokens_to_id_; std::unique_ptr tokenizer_; }; class SessionAdvancedTest : public testing::Test { protected: void SetUp() override { auto tokenizer = ExtendedTokenizer::CreateFromFile( (std::filesystem::path(::testing::SrcDir()) / std::string(kTestdataDir) / "sentencepiece.model") .string()); ASSERT_OK(tokenizer); tokenizer.value()->SetExtendedToken(256000, ""); tokenizer_ = std::move(*tokenizer); model_resources_ = std::unique_ptr(); sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); } absl::StatusOr> CreateTestSession() { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSIGN_OR_RETURN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSIGN_OR_RETURN( execution_manager_, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); return SessionAdvanced::Create(execution_manager_, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt); } std::unique_ptr tokenizer_; std::unique_ptr model_resources_; proto::SamplerParameters sampler_params_; std::shared_ptr execution_manager_; }; absl::StatusOr> CreateAudioExecutorSettings(const std::string& model_path, int max_sequence_length, Backend backend) { ASSIGN_OR_RETURN(auto model_file, ScopedFile::Open(model_path)); auto model_file_ptr = std::make_shared(std::move(model_file)); ASSIGN_OR_RETURN(auto model_assets, ModelAssets::Create(model_file_ptr)); // Create the audio executor settings. ASSIGN_OR_RETURN(auto audio_executor_settings, AudioExecutorSettings::CreateDefault( model_assets, max_sequence_length, backend)); return std::make_unique( std::move(audio_executor_settings)); } absl::AnyInvocable)> CreateStreamingTestCallback( absl::Status& status_ref, TaskState& state_ref, std::vector& texts_ref, bool delay_on_next = false) { return [&status_ref, &state_ref, &texts_ref, delay_on_next](absl::StatusOr responses) mutable { if (!responses.ok()) { status_ref = std::move(responses.status()); return; } state_ref = responses->GetTaskState(); if (IsTaskEndState(state_ref)) { return; } if (delay_on_next) { absl::SleepFor(absl::Milliseconds(50)); } if (!responses->GetTexts().empty()) { EXPECT_EQ(responses->GetTexts().size(), 1); texts_ref.push_back(responses->GetTexts()[0]); } }; } TEST_F(SessionAdvancedTest, RunPrefill) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // The prefill tokens are the expected tokens that will be passed in // at each time the Prefill function is called. The values are the // token ids of the input prompt "Hello World!". // The decode tokens are the expected tokens that will be returned // by the Decode function. The values are the token ids of the // output response "How's it going?" followed by the stop token id // (2294). 
/*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); } TEST_F(SessionAdvancedTest, EmptyInputTextReturnsError) { SessionConfig session_config = SessionConfig::CreateDefault(); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN(auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{}}, /*decode_tokens=*/{{}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("")); EXPECT_THAT(session->RunPrefill(inputs), StatusIs(absl::StatusCode::kInvalidArgument, "No token IDs found in preprocessed_contents.")); } TEST_F(SessionAdvancedTest, RunDecodeWithInternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); auto responses = session->RunDecode(); EXPECT_OK(responses); // Expect a single output candidate. EXPECT_EQ(responses->GetTexts().size(), 1); // The response is " How's it going?" since "!" is the stop token which is // not included in the response. EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); } TEST_F(SessionAdvancedTest, RunDecodeWithMaxOutputTokens) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetMaxOutputTokens(2); auto responses = session->RunDecode(decode_config); EXPECT_OK(responses); // Expect a single output candidate. EXPECT_EQ(responses->GetTexts().size(), 1); EXPECT_EQ(responses->GetTexts()[0], " How'"); } TEST_F(SessionAdvancedTest, RunDecodeWithExternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); auto responses = session->RunDecode(); EXPECT_OK(responses); // Expect a single output candidate. EXPECT_EQ(responses->GetTexts().size(), 1); // The response is " How's it going?" since "!" is the stop token which is // not included in the response. EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); } TEST_F(SessionAdvancedTest, RunDecodeWithMultipleOutputCandidatesWithInternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetNumOutputCandidates(3); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?", "Hello World", "How's it going?" 
/*decode_tokens=*/{{224, 90, 224}, {24, 547, 24}, {8, 58, 8}, {66, 735, 66}, {246, 210, 246}, {18, 466, 18}, {2295, 2294, 2295}, {2294, 0, 2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); auto responses = session->RunDecode(); EXPECT_OK(responses); EXPECT_EQ(responses->GetTexts().size(), 3); // The response is " How's it going?" since "!" is the stop token which is // not included in the response. EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); EXPECT_EQ(responses->GetTexts()[1], " Hello World"); EXPECT_EQ(responses->GetTexts()[2], " How's it going?"); } TEST_F(SessionAdvancedTest, RunDecodeWithMultipleOutputCandidatesWithExternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetNumOutputCandidates(3); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?", "Hello World", "How's it going?" /*decode_tokens=*/{{224, 90, 224}, {24, 547, 24}, {8, 58, 8}, {66, 735, 66}, {246, 210, 246}, {18, 466, 18}, {2295, 2294, 2295}, {2294, 0, 2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); auto responses = session->RunDecode(); EXPECT_OK(responses); EXPECT_EQ(responses->GetTexts().size(), 3); // The response is " How's it going?" since "!" is the stop token which is // not included in the response. EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); EXPECT_EQ(responses->GetTexts()[1], " Hello World"); EXPECT_EQ(responses->GetTexts()[2], " How's it going?"); } TEST_F(SessionAdvancedTest, RunDecodeWithConstrainedDecodingWithInternalSampler) { // Fake constraint that expects "'s it". std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. 
proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("How")); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto responses, session->RunDecode(decode_config)); // Expect a single output candidate. EXPECT_EQ(responses.GetTexts().size(), 1); EXPECT_EQ(responses.GetTexts()[0], "'s it"); } TEST_F(SessionAdvancedTest, RunDecodeWithConstrainedDecodingWithExternalSampler) { // Fake constraint that expects "'s it". std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("How")); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto responses, session->RunDecode(decode_config)); // Expect a single output candidate. 
EXPECT_EQ(responses.GetTexts().size(), 1); EXPECT_EQ(responses.GetTexts()[0], "'s it"); } absl::AnyInvocable)> CreateTestCallback( bool& done_ref) { return [&done_ref](absl::StatusOr responses) mutable { if (responses.ok() && responses->GetTexts().empty()) { done_ref = true; } }; } TEST_F(SessionAdvancedTest, RunPrefillAsync) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.SetStartTokenId(2); session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); bool done = false; auto callback = CreateTestCallback(done); EXPECT_OK(session->RunPrefillAsync(inputs, std::move(callback))); // Wait for the async call to finish. EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(100))); EXPECT_TRUE(done); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithInternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.SetStartTokenId(2); session_config.GetMutableStopTokenIds() = stop_token_ids; ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); bool done_prefill = false; EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); bool done_decode = false; EXPECT_OK(session->RunDecodeAsync(CreateTestCallback(done_decode))); EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(100))); EXPECT_TRUE(done_prefill); EXPECT_TRUE(done_decode); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithExternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.SetStartTokenId(2); session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); bool done_prefill = false; EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); bool done_decode = false; EXPECT_OK(session->RunDecodeAsync(CreateTestCallback(done_decode))); EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(100))); EXPECT_TRUE(done_prefill); EXPECT_TRUE(done_decode); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithConstrainedDecodingWithInternalSampler) { // Fake constraint that expects "'s it". std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("How")); bool done_prefill = false; EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector texts; auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, texts), decode_config)); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDone); EXPECT_EQ(texts.size(), 3); EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithConstrainedDecodingWithExternalSampler) { // Fake constraint that expects "'s it". std::vector expected_token_ids = {24, 8, 66, 0}; auto constraint = FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); const std::vector> stop_token_ids = {{2294}, {0}}; // Top P sampler. 
proto::SamplerParameters sampler_params; sampler_params.set_type(proto::SamplerParameters::TOP_P); sampler_params.set_k(1); sampler_params.set_temperature(1.0); sampler_params.set_p(0.5); sampler_params.set_seed(1); SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( /*prefill_tokens=*/{{2, 224}, // The first prefill. {0}}, // The expected prefill tokens that after // stop tokens are found in decoding with // sampler. That is, the last // sampled tokens at stop condition. // "How's it going?" /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("How")); bool done_prefill = false; EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector texts; auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetConstraint(&constraint); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, texts), decode_config)); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDone); EXPECT_EQ(texts.size(), 3); EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); } TEST_F(SessionAdvancedTest, SaveAndRewindCheckpoint) { ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); EXPECT_OK(session->SaveCheckpoint("checkpoint-1")); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetMaxOutputTokens(2); ASSERT_OK_AND_ASSIGN(auto responses1, session->RunDecode(decode_config)); EXPECT_EQ(responses1.GetTexts().size(), 1); EXPECT_EQ(responses1.GetTexts()[0], " How'"); EXPECT_OK(session->SaveCheckpoint("checkpoint-2")); EXPECT_OK(session->RewindToCheckpoint("checkpoint-1")); decode_config.SetMaxOutputTokens(2); ASSERT_OK_AND_ASSIGN(auto responses3, session->RunDecode(decode_config)); EXPECT_EQ(responses3.GetTexts().size(), 1); EXPECT_EQ(responses3.GetTexts()[0], " How'"); EXPECT_THAT(session->RewindToCheckpoint("checkpoint-2"), StatusIs(absl::StatusCode::kNotFound)); EXPECT_THAT(session->RewindToCheckpoint("non-existent"), StatusIs(absl::StatusCode::kNotFound)); } TEST_F(SessionAdvancedTest, GetCurrentStep) { ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); // Initially step should be 0. ASSERT_OK_AND_ASSIGN(int step1, session->GetCurrentStep()); EXPECT_EQ(step1, 0); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); // After prefill, step should be number of prefill tokens. // Fake executor uses 8 tokens for "Hello World!". 
ASSERT_OK_AND_ASSIGN(int step2, session->GetCurrentStep()); EXPECT_EQ(step2, 8); auto decode_config = DecodeConfig::CreateDefault(); decode_config.SetMaxOutputTokens(2); ASSERT_OK_AND_ASSIGN(auto responses, session->RunDecode(decode_config)); // After decode, step should increase by number of decoded tokens. ASSERT_OK_AND_ASSIGN(int step3, session->GetCurrentStep()); EXPECT_EQ(step3, 10); } TEST_F(SessionAdvancedTest, RunPrefillAndDecodeAsyncWithInternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector texts; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, texts))); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDone); EXPECT_EQ(texts.size(), 7); EXPECT_THAT(texts, testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); } TEST_F(SessionAdvancedTest, RunPrefillAndDecodeAsyncWithExternalSampler) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); // CPU backend will use internal sampler. session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector texts; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, texts))); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDone); EXPECT_EQ(texts.size(), 7); EXPECT_THAT(texts, testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); } TEST_F(SessionAdvancedTest, GenerateContentStream) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state; std::vector texts; EXPECT_OK(session->GenerateContentStream( inputs, CreateStreamingTestCallback(status, task_state, texts))); EXPECT_OK(session->WaitUntilDone()); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDone); EXPECT_EQ(texts.size(), 7); EXPECT_THAT(texts, testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); } TEST_F(SessionAdvancedTest, RunPrefillEmptyInput) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; EXPECT_THAT(session->RunPrefill(inputs), StatusIs(absl::StatusCode::kInvalidArgument, "No token IDs found in preprocessed_contents.")); } TEST_F(SessionAdvancedTest, RunPrefillAsyncFailed) { // Configure the executor to fail at prefill. ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto* fake_executor = static_cast(executor.get()); fake_executor->SetPrefillStatus(absl::InternalError("Prefill failed")); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector texts; EXPECT_OK(session->RunPrefillAsync( inputs, CreateStreamingTestCallback(status, task_state, texts))); EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(10))); EXPECT_FALSE(status.ok()); EXPECT_EQ(task_state, TaskState::kProcessing); EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Prefill failed")); } TEST_F(SessionAdvancedTest, RunDecodeAsyncFailed) { // Configure the executor to fail at decode. ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); auto* fake_executor = static_cast(executor.get()); fake_executor->SetDecodeStatus(absl::InternalError("Decode failed")); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector texts; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, texts))); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_FALSE(status.ok()); EXPECT_EQ(task_state, TaskState::kProcessing); EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Decode failed")); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithCancellationWithInternalSampler) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector responses; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); // Wait for a short time to ensure the decoding has started. absl::SleepFor(absl::Milliseconds(100)); // Cancel the process. session->CancelProcess(); // Wait for the callback to be done. EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kCancelled); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithCancellationWithExternalSampler) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector responses; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN(auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); // Wait for a short time to ensure the decoding has started. absl::SleepFor(absl::Milliseconds(100)); // Cancel the process. session->CancelProcess(); // Wait for the callback to be done. EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kCancelled); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithTaskCancellationWithInternalSampler) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector responses; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN( auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); // Wait for a short time to ensure the decoding has started. absl::SleepFor(absl::Milliseconds(100)); // Cancel the task. EXPECT_OK(task_controller->Cancel()); // Wait for the callback to be done. EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kCancelled); } TEST_F(SessionAdvancedTest, RunDecodeAsyncWithTaskCancellationWithExternalSampler) { // Configure the executor to have a delay to simulate a long-running task. 
ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector responses; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN( auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); // Wait for a short time to ensure the decoding has started. absl::SleepFor(absl::Milliseconds(100)); // Cancel the task. EXPECT_OK(task_controller->Cancel()); // Wait for the callback to be done. EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kCancelled); } class SessionAdvancedCancellationTest : public testing::TestWithParam { protected: void SetUp() override { auto tokenizer = ExtendedTokenizer::CreateFromFile( (std::filesystem::path(::testing::SrcDir()) / std::string(kTestdataDir) / "sentencepiece.model") .string()); ASSERT_OK(tokenizer); tokenizer.value()->SetExtendedToken(256000, ""); tokenizer_ = std::move(*tokenizer); model_resources_ = std::unique_ptr(); sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); } bool use_benchmark_info_ = GetParam(); std::unique_ptr tokenizer_; std::unique_ptr model_resources_; proto::SamplerParameters sampler_params_; }; TEST_P(SessionAdvancedCancellationTest, RunDecodeAsyncCancelThenGenerateWithBenchmarkWithInternalSamplerFailed) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}, // The second prefill doesn't have bos token. {90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); std::optional benchmark_info; if (use_benchmark_info_) { proto::BenchmarkParams benchmark_params; benchmark_info.emplace(benchmark_params); } ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, benchmark_info)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector responses; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN( auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); // Cancel the process. session->CancelProcess(); // Wait for the callback to be done. EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kCancelled); // Generate again after cancellation. // The second generation should succeed. status = absl::OkStatus(); responses.clear(); EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN( task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDependentTaskCancelled); } TEST_P(SessionAdvancedCancellationTest, RunDecodeAsyncCancelThenGenerateWithBenchmarkWithExternalSamplerFailed) { // Configure the executor to have a delay to simulate a long-running task. ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}, // The second prefill doesn't have bos token. {90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" 
/*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); fake_executor->SetDecodeDelay(absl::Milliseconds(200)); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetUseExternalSampler(true); session_config.SetSamplerBackend(Backend::CPU); std::optional benchmark_info; if (use_benchmark_info_) { proto::BenchmarkParams benchmark_params; benchmark_info.emplace(benchmark_params); } ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, benchmark_info)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state = TaskState::kUnknown; std::vector responses; EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN( auto task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); // Cancel the process. session->CancelProcess(); // Wait for the callback to be done. EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kCancelled); // Generate again after cancellation. // The second generation should succeed. status = absl::OkStatus(); responses.clear(); EXPECT_OK(session->RunPrefill(inputs)); ASSERT_OK_AND_ASSIGN( task_controller, session->RunDecodeAsync(CreateStreamingTestCallback( status, task_state, responses, /*delay_on_next=*/true))); EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDependentTaskCancelled); } INSTANTIATE_TEST_SUITE_P(SessionAdvancedCancellationTest, SessionAdvancedCancellationTest, testing::Bool(), testing::PrintToStringParamName()); TEST_F(SessionAdvancedTest, RunPrefillAsyncOnCancelledSession) { ASSERT_OK_AND_ASSIGN( auto fake_executor, CreateFakeLlmExecutor( // "Hello World!" /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, // "How's it going?" /*decode_tokens=*/{ {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(fake_executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, /*benchmark_info=*/std::nullopt)); session->CancelProcess(); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); absl::Status status; TaskState task_state; std::vector responses; // The session is cancelled, so the call should return with a kCancelled // error. 
EXPECT_OK(session->RunPrefillAsync( inputs, CreateStreamingTestCallback(status, task_state, responses))); // Wait for the callback to be done. EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(10))); EXPECT_OK(status); EXPECT_EQ(task_state, TaskState::kDone); } TEST_F(SessionAdvancedTest, TestBenchmarkModeWithoutNumPrefillTokensRespectPromptTemplate) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // Expected tokens: "User\nHello World!" + // "\nModel\n" /*prefill_tokens=*/{{2, 4, 0, 39, 637, 0, 3328, 8, 179, 90, 547, 58, 735, 210, 466, 2294}, {0, 40, 23, 0, 4, 0, 39, 637, 0, 197, 979, 3076}}, /*decode_tokens=*/{{224}})); proto::BenchmarkParams benchmark_params; BenchmarkInfo benchmark_info(benchmark_params); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, benchmark_info)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); EXPECT_EQ(session->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1); } TEST_F(SessionAdvancedTest, TestBenchmarkModeWithNumPrefillTokensIgnorePromptTemplate) { const std::vector> stop_token_ids = {{2294}}; SessionConfig session_config = SessionConfig::CreateDefault(); session_config.GetMutableSamplerParams() = sampler_params_; session_config.GetMutableStopTokenIds() = stop_token_ids; session_config.SetStartTokenId(2); session_config.SetSamplerBackend(Backend::CPU); session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( "User\n"); session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( "\n"); session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( "Model\n"); ASSERT_OK_AND_ASSIGN( auto executor, CreateFakeLlmExecutor( // Expected tokens: "Hello World!" (No templates) /*prefill_tokens=*/{{90, 547, 58, 735, 210, 466, 2294}}, /*decode_tokens=*/{{224}})); proto::BenchmarkParams benchmark_params; benchmark_params.set_num_prefill_tokens(7); BenchmarkInfo benchmark_info(benchmark_params); ASSERT_OK_AND_ASSIGN( std::shared_ptr execution_manager, ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), std::move(executor), /*vision_executor_settings=*/nullptr, /*audio_executor_settings=*/nullptr, /*litert_env=*/nullptr)); ASSERT_OK_AND_ASSIGN( auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(), session_config, benchmark_info)); std::vector inputs; inputs.emplace_back(InputText("Hello World!")); EXPECT_OK(session->RunPrefill(inputs)); EXPECT_EQ(session->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1); } TEST_F(SessionAdvancedTest, PrefillAndDecodeWithConstrainedDecodingWithInternalSampler) { // Fake constraint that expects "'s it". 
  std::vector<int> expected_token_ids = {24, 8, 66, 0};
  auto constraint =
      FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560);
  const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}};
  // Top P sampler.
  proto::SamplerParameters sampler_params;
  sampler_params.set_type(proto::SamplerParameters::TOP_P);
  sampler_params.set_k(1);
  sampler_params.set_temperature(1.0);
  sampler_params.set_p(0.5);
  sampler_params.set_seed(1);

  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/{{2, 224},  // The first prefill.
                              // The expected prefill tokens after stop tokens
                              // are found during decoding with the sampler,
                              // i.e. the last sampled tokens at the stop
                              // condition.
                              {0}},
          // "How's it going?"
          /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  auto session =
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt);

  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("How"));
  absl::Status status;
  TaskState task_state;
  std::vector<std::string> texts;
  auto decode_config = DecodeConfig::CreateDefault();
  decode_config.SetConstraint(&constraint);
  EXPECT_OK((*session)->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(
      auto task_controller,
      (*session)->RunDecodeAsync(
          CreateStreamingTestCallback(status, task_state, texts),
          decode_config));
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kDone);
  EXPECT_EQ(texts.size(), 3);
  EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it"));
}

TEST_F(SessionAdvancedTest,
       PrefillAndDecodeWithConstrainedDecodingWithExternalSampler) {
  // Fake constraint that expects "'s it".
  std::vector<int> expected_token_ids = {24, 8, 66, 0};
  auto constraint =
      FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560);
  const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}};
  // Top P sampler.
  proto::SamplerParameters sampler_params;
  sampler_params.set_type(proto::SamplerParameters::TOP_P);
  sampler_params.set_k(1);
  sampler_params.set_temperature(1.0);
  sampler_params.set_p(0.5);
  sampler_params.set_seed(1);

  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/{{2, 224},  // The first prefill.
                              // The expected prefill tokens after stop tokens
                              // are found during decoding with the sampler,
                              // i.e. the last sampled tokens at the stop
                              // condition.
                              {0}},
          // "How's it going?"
          /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  auto session =
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt);

  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("How"));
  absl::Status status;
  TaskState task_state;
  std::vector<std::string> texts;
  auto decode_config = DecodeConfig::CreateDefault();
  decode_config.SetConstraint(&constraint);
  EXPECT_OK((*session)->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(
      auto task_controller,
      (*session)->RunDecodeAsync(
          CreateStreamingTestCallback(status, task_state, texts),
          decode_config));
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kDone);
  EXPECT_EQ(texts.size(), 3);
  EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it"));
}

TEST_F(SessionAdvancedTest, RunIncrementalPrefillWithDecode) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User:");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "[END]");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "Model:");
  session_config.GetMutableLlmModelType().mutable_gemma3n();
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/
          {
              // Prefill chunk 1.1.
              {2, 423, 8, 179, 29, 207, 19, 547, 58},
              // Prefill chunk 1.2.
              {735, 210, 466, 2294},
              // Prefill run before decode with the turn-change template.
              {433, 2172, 1920, 432, 197, 979, 3076, 29},
              // Prefill chunk 2.1.
              {423, 8, 179, 29, 207, 19, 547, 58, 735, 210, 466, 2294},
              // Prefill run before decode with the turn-change template.
              {433, 2172, 1920, 432, 197, 979, 3076, 29},
          },
          /*decode_tokens=*/
          {{1}, {2}, {3}, {2294}, {1}, {2}, {3}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  {
    std::vector<InputData> inputs;
    inputs.emplace_back(InputText("Hello "));
    EXPECT_OK(session->RunPrefill(inputs));
  }
  {
    std::vector<InputData> inputs;
    inputs.emplace_back(InputText("World!"));
    EXPECT_OK(session->RunPrefill(inputs));
  }
  { EXPECT_OK(session->RunDecode()); }
  {
    std::vector<InputData> inputs;
    inputs.emplace_back(InputText("Hello World!"));
    EXPECT_OK(session->RunPrefill(inputs));
  }
  { EXPECT_OK(session->RunDecode()); }
}

#if !defined(WIN32) && !defined(_WIN32) && !defined(__WIN32__) && \
    !defined(__NT__) && !defined(_WIN64)
TEST_F(SessionAdvancedTest, ProcessAndCombineContentsTextAndAudioSuccess) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.SetAudioModalityEnabled(true);
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User:");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "[END]");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "Model:");
  session_config.GetMutableLlmModelType().mutable_gemma3n();
  ASSERT_OK_AND_ASSIGN(
      auto audio_executor_settings,
      CreateAudioExecutorSettings((std::filesystem::path(::testing::SrcDir()) /
                                   std::string(kTestAudioModelPath))
                                      .string(),
                                  /*max_sequence_length=*/0, Backend::CPU));
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "User:Hello World![END]Model:"
          /*prefill_tokens=*/{{2, 423, 8, 179, 29, 207, 19, 547, 58, 735, 210,
                               466, 2294, 256000, -2, -2, -2, -2, -2, -4},
                              {433, 2172, 1920, 432, 197, 979, 3076, 29}},
          // "How's it going?"
          /*decode_tokens=*/
          {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}},
          /*audio_embedding=*/
          std::vector<float>(kExpectedAudioEmbedding.begin(),
                             kExpectedAudioEmbedding.end())));
  LITERT_ASSERT_OK_AND_ASSIGN(
      auto env, Environment::Create(std::vector<Environment::Option>()));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(
          tokenizer_.get(), model_resources_.get(), std::move(executor),
          /*vision_executor_settings=*/nullptr,
          /*audio_executor_settings=*/std::move(audio_executor_settings),
          /*litert_env=*/&env));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));

  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  LITERT_ASSERT_OK_AND_ASSIGN(
      TensorBuffer mel_spectrogram,
      CopyToTensorBuffer<float>(
          mel_spectrogram_data,
          {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots}));
  InputAudio input_audio(std::move(mel_spectrogram));
  inputs.emplace_back(std::move(input_audio));
  inputs.emplace_back(InputAudioEnd());
  EXPECT_OK(session->RunPrefill(inputs));
}

TEST_F(SessionAdvancedTest, ProcessAndCombineContentsTextAudioTextSuccess) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.SetAudioModalityEnabled(true);
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User:");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "[END]");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "Model:");
  session_config.GetMutableLlmModelType().mutable_gemma3n();
  ASSERT_OK_AND_ASSIGN(
      auto audio_executor_settings,
      CreateAudioExecutorSettings((std::filesystem::path(::testing::SrcDir()) /
                                   std::string(kTestAudioModelPath))
                                      .string(),
                                  /*max_sequence_length=*/0, Backend::CPU));
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "User:Hello World!What does the audio say?"
          // "[END]Model:"
          /*prefill_tokens=*/
          {{2, 423, 8, 179, 29, 207, 19, 547, 58, 735, 210, 466, 2294, 256000,
            -2, -2, -2, -2, -2, -4, 583, 378, 844, 166, 3, 14, 1252, 54, 58,
            626, 2295},
           {3995, 2172, 1920, 432, 197, 979, 3076, 29}},
          // "How's it going?"
          /*decode_tokens=*/
          {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}},
          /*audio_embedding=*/
          std::vector<float>(kExpectedAudioEmbedding.begin(),
                             kExpectedAudioEmbedding.end())));
  LITERT_ASSERT_OK_AND_ASSIGN(
      auto env, Environment::Create(std::vector<Environment::Option>()));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(
          tokenizer_.get(), model_resources_.get(), std::move(executor),
          /*vision_executor_settings=*/nullptr,
          /*audio_executor_settings=*/std::move(audio_executor_settings),
          /*litert_env=*/&env));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));

  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  LITERT_ASSERT_OK_AND_ASSIGN(
      TensorBuffer mel_spectrogram,
      CopyToTensorBuffer<float>(
          mel_spectrogram_data,
          {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots}));
  InputAudio input_audio(std::move(mel_spectrogram));
  inputs.emplace_back(std::move(input_audio));
  inputs.emplace_back(InputAudioEnd());
  inputs.emplace_back(InputText("What does the audio say?"));
  EXPECT_OK(session->RunPrefill(inputs));
}
#endif  // !defined(WIN32) && !defined(_WIN32) && !defined(__WIN32__) &&
        // !defined(__NT__) && !defined(_WIN64)

TEST_F(SessionAdvancedTest, RunTextScoringEmptyTargetTextFailure) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<std::string> target_text;
  EXPECT_THAT(
      session->RunTextScoring(target_text, /*store_token_lengths=*/false),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "Target text size should be 1."));
}

TEST_F(SessionAdvancedTest, RunTextScoringMultipleTargetTextFailure) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<std::string> target_text;
  target_text.push_back("How's it going?");
  target_text.push_back("How are you?");
  EXPECT_THAT(
      session->RunTextScoring(target_text, /*store_token_lengths=*/false),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "Target text size should be 1."));
}

TEST_F(SessionAdvancedTest, RunTextScoringWithoutTokenLengthsSuccess) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  std::vector<std::string> target_texts;
  target_texts.push_back("How's it going?");
  const auto responses =
      session->RunTextScoring(target_texts, /*store_token_lengths=*/false);
  EXPECT_OK(responses);
  // Expect a single output candidate with score 0.0f.
  EXPECT_EQ(responses->GetScores().size(), 1);
  EXPECT_EQ(responses->GetScores()[0], 0.0f);
  EXPECT_FALSE(responses->GetTokenLengths().has_value());
}

TEST_F(SessionAdvancedTest, RunTextScoringWithTokenLengthsSuccess) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  std::vector<std::string> target_texts;
  target_texts.push_back("How's it going?");
  const auto responses =
      session->RunTextScoring(target_texts, /*store_token_lengths=*/true);
  EXPECT_OK(responses);
  // Expect a single output candidate with score 0.0f and token length 7.
  EXPECT_EQ(responses->GetScores().size(), 1);
  EXPECT_EQ(responses->GetScores()[0], 0.0f);
  EXPECT_TRUE(responses->GetTokenLengths().has_value());
  EXPECT_EQ(responses->GetTokenLengths()->size(), 1);
  EXPECT_EQ((*responses->GetTokenLengths())[0], 7);
}

TEST_F(SessionAdvancedTest, RunTextScoringAsyncEmptyTargetTextFailure) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<std::string> target_text;
  auto controller = session->RunTextScoringAsync(
      target_text, [](absl::StatusOr<Responses> r) {},
      /*store_token_lengths=*/false);
  EXPECT_THAT(controller.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Target text size should be 1."));
}

TEST_F(SessionAdvancedTest, RunTextScoringAsyncMultipleTargetTextFailure) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<std::string> target_text;
  target_text.push_back("How's it going?");
  target_text.push_back("How are you?");
  auto controller = session->RunTextScoringAsync(
      target_text, [](absl::StatusOr<Responses> r) {},
      /*store_token_lengths=*/false);
  EXPECT_THAT(controller.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Target text size should be 1."));
}

TEST_F(SessionAdvancedTest, RunTextScoringAsyncWithoutTokenLengthsSuccess) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  std::vector<std::string> target_texts;
  target_texts.push_back("How's it going?");
  absl::Status status;
  std::optional<Responses> responses;
  ASSERT_OK_AND_ASSIGN(auto controller,
                       session->RunTextScoringAsync(
                           target_texts,
                           [&](absl::StatusOr<Responses> r) {
                             if (!r.ok()) {
                               status = r.status();
                               return;
                             }
                             if (IsTaskEndState(r->GetTaskState())) {
                               responses.emplace(*std::move(r));
                             }
                           },
                           /*store_token_lengths=*/false));
  EXPECT_OK(controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  ASSERT_TRUE(responses.has_value());
  // Expect a single output candidate with score 0.0f.
  EXPECT_EQ(responses->GetScores().size(), 1);
  EXPECT_EQ(responses->GetScores()[0], 0.0f);
  EXPECT_FALSE(responses->GetTokenLengths().has_value());
}

TEST_F(SessionAdvancedTest, RunTextScoringAsyncWithTokenLengthsSuccess) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  std::vector<std::string> target_texts;
  target_texts.push_back("How's it going?");
  absl::Status status;
  std::optional<Responses> responses;
  ASSERT_OK_AND_ASSIGN(auto controller,
                       session->RunTextScoringAsync(
                           target_texts,
                           [&](absl::StatusOr<Responses> r) {
                             if (!r.ok()) {
                               status = r.status();
                               return;
                             }
                             if (IsTaskEndState(r->GetTaskState())) {
                               responses.emplace(*std::move(r));
                             }
                           },
                           /*store_token_lengths=*/true));
  EXPECT_OK(controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  ASSERT_TRUE(responses.has_value());
  // Expect a single output candidate with score 0.0f and token length 7.
  EXPECT_EQ(responses->GetScores().size(), 1);
  EXPECT_EQ(responses->GetScores()[0], 0.0f);
  EXPECT_TRUE(responses->GetTokenLengths().has_value());
  EXPECT_EQ(responses->GetTokenLengths()->size(), 1);
  EXPECT_EQ((*responses->GetTokenLengths())[0], 7);
}

}  // namespace
}  // namespace litert::lm