Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
using ::testing::status::StatusIs;
// Directory holding shared test assets (e.g. the SentencePiece model file).
constexpr absl::string_view kTestdataDir =
    "litert_lm/runtime/components/testdata/";
// Path to the dummy audio-only .litertlm model (per its filename) used by
// audio tests.
constexpr absl::string_view kTestAudioModelPath =
    "litert_lm/runtime/testdata/dummy_audio_only.litertlm";
// Dimensions of the fake mel spectrogram input.
constexpr int kSpectrogramFrequencySlots = 8;
constexpr int kSpectrogramSequenceLength = 10;
// Dimensions of the expected audio embedding output.
constexpr int kEmbeddingSequenceLength = 5;
constexpr int kEmbeddingDimensions = 6;
// Audio embedding tensor will have shape [1, kEmbeddingSequenceLength,
// kEmbeddingDimensions].
constexpr std::array<float, kEmbeddingSequenceLength * kEmbeddingDimensions>
    kExpectedAudioEmbedding = {0., 0., 0., 0., 0., 0., 0., 1., 2., 3.,
                               3., 3., 0., 1., 2., 4., 4., 4., 1., 2.,
                               3., 5., 5., 5., 0., 1., 2., 4., 4., 4.};
// Mel spectrogram tensor will have shape [1, kSpectrogramSequenceLength,
// kSpectrogramFrequencySlots].
// NOTE(review): name does not follow the kConstant convention used above;
// verify call sites elsewhere in the file before renaming.
constexpr std::array<float,
                     kSpectrogramSequenceLength * kSpectrogramFrequencySlots>
    mel_spectrogram_data = {
        0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0.,
        0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0.,
        1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0.,
        0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1.};
| absl::StatusOr<std::unique_ptr<FakeLlmExecutor>> CreateFakeLlmExecutor( | |
| std::vector<std::vector<int>> prefill_tokens, | |
| std::vector<std::vector<int>> decode_tokens, | |
| std::optional<std::vector<float>> audio_embedding = std::nullopt) { | |
| auto batch_size = decode_tokens.empty() ? 1 : decode_tokens[0].size(); | |
| auto fake_executor = std::make_unique<FakeLlmExecutor>( | |
| 2560, prefill_tokens, decode_tokens, batch_size, audio_embedding); | |
| return std::move(fake_executor); | |
| } | |
| class ExtendedTokenizer : public Tokenizer { | |
| public: | |
| static absl::StatusOr<std::unique_ptr<ExtendedTokenizer>> CreateFromFile( | |
| absl::string_view model_path) { | |
| ASSIGN_OR_RETURN(auto tokenizer, | |
| SentencePieceTokenizer::CreateFromFile(model_path)); | |
| return absl::WrapUnique(new ExtendedTokenizer(std::move(tokenizer))); | |
| } | |
| void SetExtendedToken(int token_id, absl::string_view token_str) { | |
| extended_tokens_to_id_[token_str] = token_id; | |
| id_to_extended_tokens_[token_id] = token_str; | |
| } | |
| absl::StatusOr<std::vector<int>> TextToTokenIds( | |
| absl::string_view text) override { | |
| std::vector<int> token_ids; | |
| bool is_extended_token_found = false; | |
| do { | |
| is_extended_token_found = false; | |
| for (const auto& [extended_token_str, extended_token_id] : | |
| extended_tokens_to_id_) { | |
| auto extended_token_pos = text.find(extended_token_str); | |
| if (extended_token_pos != std::string::npos) { | |
| // The text before the extended token. | |
| ASSIGN_OR_RETURN( | |
| auto text_ids, | |
| tokenizer_->TextToTokenIds(text.substr(0, extended_token_pos))); | |
| token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); | |
| token_ids.push_back(extended_token_id); | |
| text = text.substr(extended_token_pos + extended_token_str.size()); | |
| is_extended_token_found = true; | |
| } | |
| } | |
| } while (is_extended_token_found); | |
| if (!text.empty()) { | |
| ASSIGN_OR_RETURN(auto text_ids, tokenizer_->TextToTokenIds(text)); | |
| token_ids.insert(token_ids.end(), text_ids.begin(), text_ids.end()); | |
| } | |
| return token_ids; | |
| } | |
| absl::StatusOr<std::string> TokenIdsToText( | |
| const std::vector<int>& token_ids) override { | |
| std::vector<std::string> token_strs; | |
| for (int token_id : token_ids) { | |
| if (id_to_extended_tokens_.contains(token_id)) { | |
| token_strs.push_back(id_to_extended_tokens_[token_id]); | |
| } else { | |
| token_strs.push_back(tokenizer_->TokenIdsToText({token_id}).value()); | |
| } | |
| } | |
| return absl::StrJoin(token_strs, ""); | |
| } | |
| absl::StatusOr<int> TokenToId(absl::string_view token) override { | |
| if (extended_tokens_to_id_.contains(token)) { | |
| return extended_tokens_to_id_[token]; | |
| } | |
| return tokenizer_->TokenToId(token); | |
| } | |
| TokenizerType GetTokenizerType() const override { | |
| return tokenizer_->GetTokenizerType(); | |
| } | |
| std::vector<std::string> GetTokens() const override { | |
| return tokenizer_->GetTokens(); | |
| } | |
| private: | |
| explicit ExtendedTokenizer(std::unique_ptr<SentencePieceTokenizer> tokenizer) | |
| : tokenizer_(std::move(tokenizer)) {}; | |
| absl::flat_hash_map<int, std::string> id_to_extended_tokens_; | |
| absl::flat_hash_map<std::string, int> extended_tokens_to_id_; | |
| std::unique_ptr<SentencePieceTokenizer> tokenizer_; | |
| }; | |
| class SessionAdvancedTest : public testing::Test { | |
| protected: | |
| void SetUp() override { | |
| auto tokenizer = ExtendedTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestdataDir) / "sentencepiece.model") | |
| .string()); | |
| ASSERT_OK(tokenizer); | |
| tokenizer.value()->SetExtendedToken(256000, "<start_of_audio>"); | |
| tokenizer_ = std::move(*tokenizer); | |
| model_resources_ = std::unique_ptr<ModelResources>(); | |
| sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED); | |
| } | |
| absl::StatusOr<std::unique_ptr<SessionAdvanced>> CreateTestSession() { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSIGN_OR_RETURN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSIGN_OR_RETURN( | |
| execution_manager_, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| return SessionAdvanced::Create(execution_manager_, tokenizer_.get(), | |
| session_config, | |
| /*benchmark_info=*/std::nullopt); | |
| } | |
| std::unique_ptr<Tokenizer> tokenizer_; | |
| std::unique_ptr<ModelResources> model_resources_; | |
| proto::SamplerParameters sampler_params_; | |
| std::shared_ptr<ExecutionManager> execution_manager_; | |
| }; | |
| absl::StatusOr<std::unique_ptr<AudioExecutorSettings>> | |
| CreateAudioExecutorSettings(const std::string& model_path, | |
| int max_sequence_length, Backend backend) { | |
| ASSIGN_OR_RETURN(auto model_file, ScopedFile::Open(model_path)); | |
| auto model_file_ptr = std::make_shared<ScopedFile>(std::move(model_file)); | |
| ASSIGN_OR_RETURN(auto model_assets, ModelAssets::Create(model_file_ptr)); | |
| // Create the audio executor settings. | |
| ASSIGN_OR_RETURN(auto audio_executor_settings, | |
| AudioExecutorSettings::CreateDefault( | |
| model_assets, max_sequence_length, backend)); | |
| return std::make_unique<AudioExecutorSettings>( | |
| std::move(audio_executor_settings)); | |
| } | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> CreateStreamingTestCallback( | |
| absl::Status& status_ref, TaskState& state_ref, | |
| std::vector<std::string>& texts_ref, bool delay_on_next = false) { | |
| return [&status_ref, &state_ref, &texts_ref, | |
| delay_on_next](absl::StatusOr<Responses> responses) mutable { | |
| if (!responses.ok()) { | |
| status_ref = std::move(responses.status()); | |
| return; | |
| } | |
| state_ref = responses->GetTaskState(); | |
| if (IsTaskEndState(state_ref)) { | |
| return; | |
| } | |
| if (delay_on_next) { | |
| absl::SleepFor(absl::Milliseconds(50)); | |
| } | |
| if (!responses->GetTexts().empty()) { | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| texts_ref.push_back(responses->GetTexts()[0]); | |
| } | |
| }; | |
| } | |
| TEST_F(SessionAdvancedTest, RunPrefill) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // The prefill tokens are the expected tokens that will be passed in | |
| // at each time the Prefill function is called. The values are the | |
| // token ids of the input prompt "Hello World!". | |
| // The decode tokens are the expected tokens that will be returned | |
| // by the Decode function. The values are the token ids of the | |
| // output response "How's it going?" followed by the stop token id | |
| // (2294). | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| TEST_F(SessionAdvancedTest, EmptyInputTextReturnsError) { | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN(auto executor, CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{}}, | |
| /*decode_tokens=*/{{}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("")); | |
| EXPECT_THAT(session->RunPrefill(inputs), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "No token IDs found in preprocessed_contents.")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunDecodeWithInternalSampler) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| auto responses = session->RunDecode(); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate. | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| // The response is " How's it going?" since "!" is the stop token which is | |
| // not included in the response. | |
| EXPECT_EQ(responses->GetTexts()[0], " How's it going?"); | |
| } | |
| TEST_F(SessionAdvancedTest, RunDecodeWithMaxOutputTokens) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetMaxOutputTokens(2); | |
| auto responses = session->RunDecode(decode_config); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate. | |
| EXPECT_EQ(responses->GetTexts().size(), 1); | |
| EXPECT_EQ(responses->GetTexts()[0], " How'"); | |
| } | |
// Full prefill + decode round trip when sampling is performed by an external
// sampler on the CPU backend instead of inside the executor. Mirrors
// RunDecodeWithInternalSampler.
TEST_F(SessionAdvancedTest, RunDecodeWithExternalSampler) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  // External CPU sampling is the configuration under test.
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  auto responses = session->RunDecode();
  EXPECT_OK(responses);
  // Expect a single output candidate.
  EXPECT_EQ(responses->GetTexts().size(), 1);
  // The response is " How's it going?" since "!" is the stop token which is
  // not included in the response.
  EXPECT_EQ(responses->GetTexts()[0], " How's it going?");
}
// With SetNumOutputCandidates(3), decode produces three texts. Each row of
// decode_tokens is one decode step with one token per candidate (the row
// width also determines the fake executor's batch size — see
// CreateFakeLlmExecutor).
TEST_F(SessionAdvancedTest,
       RunDecodeWithMultipleOutputCandidatesWithInternalSampler) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetNumOutputCandidates(3);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?", "Hello World", "How's it going?"
          /*decode_tokens=*/{{224, 90, 224},
                             {24, 547, 24},
                             {8, 58, 8},
                             {66, 735, 66},
                             {246, 210, 246},
                             {18, 466, 18},
                             {2295, 2294, 2295},
                             {2294, 0, 2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  auto responses = session->RunDecode();
  EXPECT_OK(responses);
  EXPECT_EQ(responses->GetTexts().size(), 3);
  // The response is " How's it going?" since "!" is the stop token which is
  // not included in the response.
  EXPECT_EQ(responses->GetTexts()[0], " How's it going?");
  EXPECT_EQ(responses->GetTexts()[1], " Hello World");
  EXPECT_EQ(responses->GetTexts()[2], " How's it going?");
}
// Same as RunDecodeWithMultipleOutputCandidatesWithInternalSampler, but with
// the external CPU sampler enabled.
TEST_F(SessionAdvancedTest,
       RunDecodeWithMultipleOutputCandidatesWithExternalSampler) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetNumOutputCandidates(3);
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?", "Hello World", "How's it going?"
          // Each row is one decode step carrying one token per candidate.
          /*decode_tokens=*/{{224, 90, 224},
                             {24, 547, 24},
                             {8, 58, 8},
                             {66, 735, 66},
                             {246, 210, 246},
                             {18, 466, 18},
                             {2295, 2294, 2295},
                             {2294, 0, 2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  auto responses = session->RunDecode();
  EXPECT_OK(responses);
  EXPECT_EQ(responses->GetTexts().size(), 3);
  // The response is " How's it going?" since "!" is the stop token which is
  // not included in the response.
  EXPECT_EQ(responses->GetTexts()[0], " How's it going?");
  EXPECT_EQ(responses->GetTexts()[1], " Hello World");
  EXPECT_EQ(responses->GetTexts()[2], " How's it going?");
}
// Constrained decoding with the internal sampler: a FakeConstraint is built
// over the token sequence {24, 8, 66, 0} — presumably steering sampling to
// that sequence (verify against FakeConstraint) — and 0 is configured as a
// stop token, so the decoded text is "'s it".
TEST_F(SessionAdvancedTest,
       RunDecodeWithConstrainedDecodingWithInternalSampler) {
  // Fake constraint that expects "'s it".
  std::vector<int> expected_token_ids = {24, 8, 66, 0};
  auto constraint =
      FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560);
  const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}};
  // Top P sampler.
  proto::SamplerParameters sampler_params;
  sampler_params.set_type(proto::SamplerParameters::TOP_P);
  sampler_params.set_k(1);
  sampler_params.set_temperature(1.0);
  sampler_params.set_p(0.5);
  sampler_params.set_seed(1);
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/{{2, 224},  // The first prefill.
                              {0}},      // The expected prefill tokens that after
                                         // stop tokens are found in decoding with
                                         // sampler. That is, the last
                                         // sampled tokens at stop condition.
          // "How's it going?"
          /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("How"));
  auto decode_config = DecodeConfig::CreateDefault();
  decode_config.SetConstraint(&constraint);
  EXPECT_OK(session->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(auto responses, session->RunDecode(decode_config));
  // Expect a single output candidate.
  EXPECT_EQ(responses.GetTexts().size(), 1);
  EXPECT_EQ(responses.GetTexts()[0], "'s it");
}
// Same as RunDecodeWithConstrainedDecodingWithInternalSampler, but with the
// external CPU sampler enabled.
TEST_F(SessionAdvancedTest,
       RunDecodeWithConstrainedDecodingWithExternalSampler) {
  // Fake constraint that expects "'s it".
  std::vector<int> expected_token_ids = {24, 8, 66, 0};
  auto constraint =
      FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560);
  const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}};
  // Top P sampler.
  proto::SamplerParameters sampler_params;
  sampler_params.set_type(proto::SamplerParameters::TOP_P);
  sampler_params.set_k(1);
  sampler_params.set_temperature(1.0);
  sampler_params.set_p(0.5);
  sampler_params.set_seed(1);
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/{{2, 224},  // The first prefill.
                              {0}},      // The expected prefill tokens that after
                                         // stop tokens are found in decoding with
                                         // sampler. That is, the last
                                         // sampled tokens at stop condition.
          // "How's it going?"
          /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("How"));
  auto decode_config = DecodeConfig::CreateDefault();
  decode_config.SetConstraint(&constraint);
  EXPECT_OK(session->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(auto responses, session->RunDecode(decode_config));
  // Expect a single output candidate.
  EXPECT_EQ(responses.GetTexts().size(), 1);
  EXPECT_EQ(responses.GetTexts()[0], "'s it");
}
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> CreateTestCallback( | |
| bool& done_ref) { | |
| return [&done_ref](absl::StatusOr<Responses> responses) mutable { | |
| if (responses.ok() && responses->GetTexts().empty()) { | |
| done_ref = true; | |
| } | |
| }; | |
| } | |
| TEST_F(SessionAdvancedTest, RunPrefillAsync) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.SetStartTokenId(2); | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| bool done = false; | |
| auto callback = CreateTestCallback(done); | |
| EXPECT_OK(session->RunPrefillAsync(inputs, std::move(callback))); | |
| // Wait for the async call to finish. | |
| EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(100))); | |
| EXPECT_TRUE(done); | |
| } | |
// Async prefill followed by async decode with the internal sampler; both
// callbacks should observe their final (empty-text) response.
TEST_F(SessionAdvancedTest, RunDecodeAsyncWithInternalSampler) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.SetStartTokenId(2);
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  bool done_prefill = false;
  EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill)));
  bool done_decode = false;
  EXPECT_OK(session->RunDecodeAsync(CreateTestCallback(done_decode)));
  // Wait until both async tasks have drained before checking the flags.
  EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(100)));
  EXPECT_TRUE(done_prefill);
  EXPECT_TRUE(done_decode);
}
// Same as RunDecodeAsyncWithInternalSampler, but with the external CPU
// sampler enabled.
TEST_F(SessionAdvancedTest, RunDecodeAsyncWithExternalSampler) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.SetStartTokenId(2);
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                                            session_config,
                                            /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  bool done_prefill = false;
  EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill)));
  bool done_decode = false;
  EXPECT_OK(session->RunDecodeAsync(CreateTestCallback(done_decode)));
  // Wait until both async tasks have drained before checking the flags.
  EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(100)));
  EXPECT_TRUE(done_prefill);
  EXPECT_TRUE(done_decode);
}
// Async constrained decoding with the internal sampler: the streaming
// callback collects each decoded chunk; the constraint token sequence
// {24, 8, 66, 0} (0 being a stop token) yields the streamed chunks
// "'", "s", " it".
TEST_F(SessionAdvancedTest,
       RunDecodeAsyncWithConstrainedDecodingWithInternalSampler) {
  // Fake constraint that expects "'s it".
  std::vector<int> expected_token_ids = {24, 8, 66, 0};
  auto constraint =
      FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560);
  const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}};
  // Top P sampler.
  proto::SamplerParameters sampler_params;
  sampler_params.set_type(proto::SamplerParameters::TOP_P);
  sampler_params.set_k(1);
  sampler_params.set_temperature(1.0);
  sampler_params.set_p(0.5);
  sampler_params.set_seed(1);
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/{{2, 224},  // The first prefill.
                              {0}},      // The expected prefill tokens that after
                                         // stop tokens are found in decoding with
                                         // sampler. That is, the last
                                         // sampled tokens at stop condition.
          // "How's it going?"
          /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("How"));
  bool done_prefill = false;
  EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill)));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> texts;
  auto decode_config = DecodeConfig::CreateDefault();
  decode_config.SetConstraint(&constraint);
  ASSERT_OK_AND_ASSIGN(auto task_controller,
                       session->RunDecodeAsync(CreateStreamingTestCallback(
                                                   status, task_state, texts),
                                               decode_config));
  // Block on the task controller rather than the whole execution manager.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kDone);
  EXPECT_EQ(texts.size(), 3);
  EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it"));
}
| TEST_F(SessionAdvancedTest, | |
| RunDecodeAsyncWithConstrainedDecodingWithExternalSampler) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| // Top P sampler. | |
| proto::SamplerParameters sampler_params; | |
| sampler_params.set_type(proto::SamplerParameters::TOP_P); | |
| sampler_params.set_k(1); | |
| sampler_params.set_temperature(1.0); | |
| sampler_params.set_p(0.5); | |
| sampler_params.set_seed(1); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetUseExternalSampler(true); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}, // The first prefill. | |
| {0}}, // The expected prefill tokens that after | |
| // stop tokens are found in decoding with | |
| // sampler. That is, the last | |
| // sampled tokens at stop condition. | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| bool done_prefill = false; | |
| EXPECT_OK(session->RunPrefillAsync(inputs, CreateTestCallback(done_prefill))); | |
| absl::Status status; | |
| TaskState task_state = TaskState::kUnknown; | |
| std::vector<std::string> texts; | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| ASSERT_OK_AND_ASSIGN(auto task_controller, | |
| session->RunDecodeAsync(CreateStreamingTestCallback( | |
| status, task_state, texts), | |
| decode_config)); | |
| EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| EXPECT_EQ(texts.size(), 3); | |
| EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); | |
| } | |
| TEST_F(SessionAdvancedTest, SaveAndRewindCheckpoint) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| EXPECT_OK(session->SaveCheckpoint("checkpoint-1")); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetMaxOutputTokens(2); | |
| ASSERT_OK_AND_ASSIGN(auto responses1, session->RunDecode(decode_config)); | |
| EXPECT_EQ(responses1.GetTexts().size(), 1); | |
| EXPECT_EQ(responses1.GetTexts()[0], " How'"); | |
| EXPECT_OK(session->SaveCheckpoint("checkpoint-2")); | |
| EXPECT_OK(session->RewindToCheckpoint("checkpoint-1")); | |
| decode_config.SetMaxOutputTokens(2); | |
| ASSERT_OK_AND_ASSIGN(auto responses3, session->RunDecode(decode_config)); | |
| EXPECT_EQ(responses3.GetTexts().size(), 1); | |
| EXPECT_EQ(responses3.GetTexts()[0], " How'"); | |
| EXPECT_THAT(session->RewindToCheckpoint("checkpoint-2"), | |
| StatusIs(absl::StatusCode::kNotFound)); | |
| EXPECT_THAT(session->RewindToCheckpoint("non-existent"), | |
| StatusIs(absl::StatusCode::kNotFound)); | |
| } | |
| TEST_F(SessionAdvancedTest, GetCurrentStep) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| // Initially step should be 0. | |
| ASSERT_OK_AND_ASSIGN(int step1, session->GetCurrentStep()); | |
| EXPECT_EQ(step1, 0); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| // After prefill, step should be number of prefill tokens. | |
| // Fake executor uses 8 tokens for "Hello World!". | |
| ASSERT_OK_AND_ASSIGN(int step2, session->GetCurrentStep()); | |
| EXPECT_EQ(step2, 8); | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetMaxOutputTokens(2); | |
| ASSERT_OK_AND_ASSIGN(auto responses, session->RunDecode(decode_config)); | |
| // After decode, step should increase by number of decoded tokens. | |
| ASSERT_OK_AND_ASSIGN(int step3, session->GetCurrentStep()); | |
| EXPECT_EQ(step3, 10); | |
| } | |
| TEST_F(SessionAdvancedTest, RunPrefillAndDecodeAsyncWithInternalSampler) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| TaskState task_state = TaskState::kUnknown; | |
| std::vector<std::string> texts; | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto task_controller, | |
| session->RunDecodeAsync(CreateStreamingTestCallback( | |
| status, task_state, texts))); | |
| EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| EXPECT_EQ(texts.size(), 7); | |
| EXPECT_THAT(texts, | |
| testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunPrefillAndDecodeAsyncWithExternalSampler) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| // CPU backend will use internal sampler. | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| TaskState task_state = TaskState::kUnknown; | |
| std::vector<std::string> texts; | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto task_controller, | |
| session->RunDecodeAsync(CreateStreamingTestCallback( | |
| status, task_state, texts))); | |
| EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| EXPECT_EQ(texts.size(), 7); | |
| EXPECT_THAT(texts, | |
| testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); | |
| } | |
| TEST_F(SessionAdvancedTest, GenerateContentStream) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| TaskState task_state; | |
| std::vector<std::string> texts; | |
| EXPECT_OK(session->GenerateContentStream( | |
| inputs, CreateStreamingTestCallback(status, task_state, texts))); | |
| EXPECT_OK(session->WaitUntilDone()); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| EXPECT_EQ(texts.size(), 7); | |
| EXPECT_THAT(texts, | |
| testing::ElementsAre(" How", "'", "s", " it", " go", "ing", "?")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunPrefillEmptyInput) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| EXPECT_THAT(session->RunPrefill(inputs), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "No token IDs found in preprocessed_contents.")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunPrefillAsyncFailed) { | |
| // Configure the executor to fail at prefill. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto* fake_executor = static_cast<FakeLlmExecutor*>(executor.get()); | |
| fake_executor->SetPrefillStatus(absl::InternalError("Prefill failed")); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| TaskState task_state = TaskState::kUnknown; | |
| std::vector<std::string> texts; | |
| EXPECT_OK(session->RunPrefillAsync( | |
| inputs, CreateStreamingTestCallback(status, task_state, texts))); | |
| EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(10))); | |
| EXPECT_FALSE(status.ok()); | |
| EXPECT_EQ(task_state, TaskState::kProcessing); | |
| EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Prefill failed")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunDecodeAsyncFailed) { | |
| // Configure the executor to fail at decode. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| auto* fake_executor = static_cast<FakeLlmExecutor*>(executor.get()); | |
| fake_executor->SetDecodeStatus(absl::InternalError("Decode failed")); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| TaskState task_state = TaskState::kUnknown; | |
| std::vector<std::string> texts; | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto task_controller, | |
| session->RunDecodeAsync(CreateStreamingTestCallback( | |
| status, task_state, texts))); | |
| EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); | |
| EXPECT_FALSE(status.ok()); | |
| EXPECT_EQ(task_state, TaskState::kProcessing); | |
| EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal, "Decode failed")); | |
| } | |
// Cancelling the whole session (CancelProcess) while an async decode is in
// flight must end the task in kCancelled with an OK status.
TEST_F(SessionAdvancedTest, RunDecodeAsyncWithCancellationWithInternalSampler) {
  // Configure the executor to have a delay to simulate a long-running task.
  ASSERT_OK_AND_ASSIGN(
      auto fake_executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  // 200ms per decode step leaves time to cancel while decoding is in flight.
  fake_executor->SetDecodeDelay(absl::Milliseconds(200));
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(fake_executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt))
;
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> responses;
  EXPECT_OK(session->RunPrefill(inputs));
  // delay_on_next also slows the callback, widening the cancellation window.
  ASSERT_OK_AND_ASSIGN(auto task_controller,
                       session->RunDecodeAsync(CreateStreamingTestCallback(
                           status, task_state, responses,
                           /*delay_on_next=*/true)));
  // Wait for a short time to ensure the decoding has started.
  absl::SleepFor(absl::Milliseconds(100));
  // Cancel the process.
  session->CancelProcess();
  // Wait for the callback to be done.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  // Cancellation is not an error: status stays OK and the terminal task state
  // is kCancelled.
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kCancelled);
}
// Same as the internal-sampler cancellation test above, but with the external
// sampler enabled on the CPU backend.
TEST_F(SessionAdvancedTest, RunDecodeAsyncWithCancellationWithExternalSampler) {
  // Configure the executor to have a delay to simulate a long-running task.
  ASSERT_OK_AND_ASSIGN(
      auto fake_executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  // 200ms per decode step leaves time to cancel while decoding is in flight.
  fake_executor->SetDecodeDelay(absl::Milliseconds(200));
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  // Sample outside the executor on the CPU backend.
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(fake_executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> responses;
  EXPECT_OK(session->RunPrefill(inputs));
  // delay_on_next also slows the callback, widening the cancellation window.
  ASSERT_OK_AND_ASSIGN(auto task_controller,
                       session->RunDecodeAsync(CreateStreamingTestCallback(
                           status, task_state, responses,
                           /*delay_on_next=*/true)));
  // Wait for a short time to ensure the decoding has started.
  absl::SleepFor(absl::Milliseconds(100));
  // Cancel the process.
  session->CancelProcess();
  // Wait for the callback to be done.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  // Cancellation is not an error: status stays OK and the terminal task state
  // is kCancelled.
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kCancelled);
}
// Cancelling just the decode task through its TaskController (rather than the
// whole session via CancelProcess) must also terminate in kCancelled.
TEST_F(SessionAdvancedTest,
       RunDecodeAsyncWithTaskCancellationWithInternalSampler) {
  // Configure the executor to have a delay to simulate a long-running task.
  ASSERT_OK_AND_ASSIGN(
      auto fake_executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  // 200ms per decode step leaves time to cancel while decoding is in flight.
  fake_executor->SetDecodeDelay(absl::Milliseconds(200));
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(fake_executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> responses;
  EXPECT_OK(session->RunPrefill(inputs));
  // delay_on_next also slows the callback, widening the cancellation window.
  ASSERT_OK_AND_ASSIGN(
      auto task_controller,
      session->RunDecodeAsync(CreateStreamingTestCallback(
          status, task_state, responses, /*delay_on_next=*/true)));
  // Wait for a short time to ensure the decoding has started.
  absl::SleepFor(absl::Milliseconds(100));
  // Cancel the task.
  EXPECT_OK(task_controller->Cancel());
  // Wait for the callback to be done.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  // Cancellation is not an error: status stays OK and the terminal task state
  // is kCancelled.
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kCancelled);
}
// Same as the internal-sampler task-cancellation test above, but with the
// external sampler enabled on the CPU backend.
TEST_F(SessionAdvancedTest,
       RunDecodeAsyncWithTaskCancellationWithExternalSampler) {
  // Configure the executor to have a delay to simulate a long-running task.
  ASSERT_OK_AND_ASSIGN(
      auto fake_executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  // 200ms per decode step leaves time to cancel while decoding is in flight.
  fake_executor->SetDecodeDelay(absl::Milliseconds(200));
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  // Sample outside the executor on the CPU backend.
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(fake_executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> responses;
  EXPECT_OK(session->RunPrefill(inputs));
  // delay_on_next also slows the callback, widening the cancellation window.
  ASSERT_OK_AND_ASSIGN(
      auto task_controller,
      session->RunDecodeAsync(CreateStreamingTestCallback(
          status, task_state, responses, /*delay_on_next=*/true)));
  // Wait for a short time to ensure the decoding has started.
  absl::SleepFor(absl::Milliseconds(100));
  // Cancel the task.
  EXPECT_OK(task_controller->Cancel());
  // Wait for the callback to be done.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  // Cancellation is not an error: status stays OK and the terminal task state
  // is kCancelled.
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kCancelled);
}
// Parameterized fixture for the cancel-then-regenerate tests below; the bool
// test parameter selects whether a BenchmarkInfo is attached to the session.
class SessionAdvancedCancellationTest : public testing::TestWithParam<bool> {
 protected:
  void SetUp() override {
    auto tokenizer = ExtendedTokenizer::CreateFromFile(
        (std::filesystem::path(::testing::SrcDir()) /
         std::string(kTestdataDir) / "sentencepiece.model")
            .string());
    ASSERT_OK(tokenizer);
    // Register the audio marker token used by multimodal prompts.
    tokenizer.value()->SetExtendedToken(256000, "<start_of_audio>");
    tokenizer_ = std::move(*tokenizer);
    // NOTE(review): this assigns a default-constructed (null) unique_ptr, so
    // model_resources_.get() passed to ExecutionManager::Create in the tests
    // is nullptr — presumably the fake-executor path tolerates missing model
    // resources; confirm this is intentional.
    model_resources_ = std::unique_ptr<ModelResources>();
    sampler_params_.set_type(proto::SamplerParameters::TYPE_UNSPECIFIED);
  }
  // Test parameter: whether to construct the session with BenchmarkInfo.
  bool use_benchmark_info_ = GetParam();
  std::unique_ptr<Tokenizer> tokenizer_;
  std::unique_ptr<ModelResources> model_resources_;
  proto::SamplerParameters sampler_params_;
};
// Cancels an in-flight decode, then runs prefill + decode again; the second
// decode is asserted to end in kDependentTaskCancelled (see note below).
TEST_P(SessionAdvancedCancellationTest,
       RunDecodeAsyncCancelThenGenerateWithBenchmarkWithInternalSamplerFailed) {
  // Configure the executor to have a delay to simulate a long-running task.
  ASSERT_OK_AND_ASSIGN(
      auto fake_executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294},
                              // The second prefill doesn't have bos token.
                              {90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  fake_executor->SetDecodeDelay(absl::Milliseconds(200));
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  // The bool test parameter toggles whether benchmark info is attached.
  std::optional<BenchmarkInfo> benchmark_info;
  if (use_benchmark_info_) {
    proto::BenchmarkParams benchmark_params;
    benchmark_info.emplace(benchmark_params);
  }
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(fake_executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                                            session_config, benchmark_info));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> responses;
  EXPECT_OK(session->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(
      auto task_controller,
      session->RunDecodeAsync(CreateStreamingTestCallback(
          status, task_state, responses, /*delay_on_next=*/true)));
  // Cancel the process.
  session->CancelProcess();
  // Wait for the callback to be done.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kCancelled);
  // Generate again after the cancellation. The calls are accepted, but the
  // second decode terminates in kDependentTaskCancelled rather than kDone —
  // presumably because it depends on work cancelled above (matching the
  // "...Failed" suffix of this test's name); confirm against the scheduler
  // semantics.
  status = absl::OkStatus();
  responses.clear();
  EXPECT_OK(session->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(
      task_controller,
      session->RunDecodeAsync(CreateStreamingTestCallback(
          status, task_state, responses, /*delay_on_next=*/true)));
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kDependentTaskCancelled);
}
// External-sampler variant of the cancel-then-regenerate test above.
TEST_P(SessionAdvancedCancellationTest,
       RunDecodeAsyncCancelThenGenerateWithBenchmarkWithExternalSamplerFailed) {
  // Configure the executor to have a delay to simulate a long-running task.
  ASSERT_OK_AND_ASSIGN(
      auto fake_executor,
      CreateFakeLlmExecutor(
          // "Hello World!"
          /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294},
                              // The second prefill doesn't have bos token.
                              {90, 547, 58, 735, 210, 466, 2294}},
          // "How's it going?"
          /*decode_tokens=*/{
              {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}));
  fake_executor->SetDecodeDelay(absl::Milliseconds(200));
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  // Sample outside the executor on the CPU backend.
  session_config.SetUseExternalSampler(true);
  session_config.SetSamplerBackend(Backend::CPU);
  // The bool test parameter toggles whether benchmark info is attached.
  std::optional<BenchmarkInfo> benchmark_info;
  if (use_benchmark_info_) {
    proto::BenchmarkParams benchmark_params;
    benchmark_info.emplace(benchmark_params);
  }
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(fake_executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                                            session_config, benchmark_info));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  absl::Status status;
  TaskState task_state = TaskState::kUnknown;
  std::vector<std::string> responses;
  EXPECT_OK(session->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(
      auto task_controller,
      session->RunDecodeAsync(CreateStreamingTestCallback(
          status, task_state, responses, /*delay_on_next=*/true)));
  // Cancel the process.
  session->CancelProcess();
  // Wait for the callback to be done.
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kCancelled);
  // Generate again after the cancellation. The calls are accepted, but the
  // second decode terminates in kDependentTaskCancelled rather than kDone —
  // presumably because it depends on work cancelled above (matching the
  // "...Failed" suffix of this test's name); confirm against the scheduler
  // semantics.
  status = absl::OkStatus();
  responses.clear();
  EXPECT_OK(session->RunPrefill(inputs));
  ASSERT_OK_AND_ASSIGN(
      task_controller,
      session->RunDecodeAsync(CreateStreamingTestCallback(
          status, task_state, responses, /*delay_on_next=*/true)));
  EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  EXPECT_EQ(task_state, TaskState::kDependentTaskCancelled);
}
// Run every SessionAdvancedCancellationTest twice: with and without benchmark
// info attached (the bool test parameter).
INSTANTIATE_TEST_SUITE_P(SessionAdvancedCancellationTest,
                         SessionAdvancedCancellationTest, testing::Bool(),
                         testing::PrintToStringParamName());
| TEST_F(SessionAdvancedTest, RunPrefillAsyncOnCancelledSession) { | |
| ASSERT_OK_AND_ASSIGN( | |
| auto fake_executor, | |
| CreateFakeLlmExecutor( | |
| // "Hello World!" | |
| /*prefill_tokens=*/{{2, 90, 547, 58, 735, 210, 466, 2294}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/{ | |
| {224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(fake_executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| session->CancelProcess(); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| absl::Status status; | |
| TaskState task_state; | |
| std::vector<std::string> responses; | |
| // The session is cancelled, so the call should return with a kCancelled | |
| // error. | |
| EXPECT_OK(session->RunPrefillAsync( | |
| inputs, CreateStreamingTestCallback(status, task_state, responses))); | |
| // Wait for the callback to be done. | |
| EXPECT_OK(execution_manager->WaitUntilAllDone(absl::Seconds(10))); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| } | |
// Benchmark mode without num_prefill_tokens: the input text is still wrapped
// with the configured prompt templates before tokenization, so the fake
// executor expects the tokens for the fully templated prompt below.
TEST_F(SessionAdvancedTest,
       TestBenchmarkModeWithoutNumPrefillTokensRespectPromptTemplate) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "<test>User\n");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "<end>\n");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "<test>Model\n");
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // Expected tokens: "</s><test>User\nHello World!" +
          // "<end>\n<test>Model\n"
          /*prefill_tokens=*/{{2, 4, 0, 39, 637, 0, 3328, 8, 179, 90, 547, 58,
                               735, 210, 466, 2294},
                              {0, 40, 23, 0, 4, 0, 39, 637, 0, 197, 979, 3076}},
          /*decode_tokens=*/{{224}}));
  // Default BenchmarkParams: num_prefill_tokens is unset, so the prompt
  // templates above must be respected.
  proto::BenchmarkParams benchmark_params;
  BenchmarkInfo benchmark_info(benchmark_params);
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                                            session_config, benchmark_info));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  // One RunPrefill call should be recorded as exactly one benchmark turn.
  EXPECT_EQ(session->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1);
}
// Benchmark mode with num_prefill_tokens set: the prompt templates configured
// below are bypassed and only the raw input text is tokenized, as shown by
// the expected prefill tokens in the fake executor.
TEST_F(SessionAdvancedTest,
       TestBenchmarkModeWithNumPrefillTokensIgnorePromptTemplate) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "<test>User\n");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "<end>\n");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "<test>Model\n");
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          // Expected tokens: "Hello World!" (No templates)
          /*prefill_tokens=*/{{90, 547, 58, 735, 210, 466, 2294}},
          /*decode_tokens=*/{{224}}));
  proto::BenchmarkParams benchmark_params;
  // Setting num_prefill_tokens is what makes the session ignore templates.
  benchmark_params.set_num_prefill_tokens(7);
  BenchmarkInfo benchmark_info(benchmark_params);
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session, SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                                            session_config, benchmark_info));
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  // One RunPrefill call should be recorded as exactly one benchmark turn.
  EXPECT_EQ(session->GetBenchmarkInfo()->GetTotalPrefillTurns(), 1);
}
| TEST_F(SessionAdvancedTest, | |
| PrefillAndDecodeWithConstrainedDecodingWithInternalSampler) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| // Top P sampler. | |
| proto::SamplerParameters sampler_params; | |
| sampler_params.set_type(proto::SamplerParameters::TOP_P); | |
| sampler_params.set_k(1); | |
| sampler_params.set_temperature(1.0); | |
| sampler_params.set_p(0.5); | |
| sampler_params.set_seed(1); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}, // The first prefill. | |
| {0}}, // The expected prefill tokens that after | |
| // stop tokens are found in decoding with | |
| // sampler. That is, the last | |
| // sampled tokens at stop condition. | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| auto session = | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| absl::Status status; | |
| TaskState task_state; | |
| std::vector<std::string> texts; | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto task_controller, (*session)->RunDecodeAsync( | |
| CreateStreamingTestCallback( | |
| status, task_state, texts), | |
| decode_config)); | |
| EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| EXPECT_EQ(texts.size(), 3); | |
| EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); | |
| } | |
| TEST_F(SessionAdvancedTest, | |
| PrefillAndDecodeWithConstrainedDecodingWithExternalSampler) { | |
| // Fake constraint that expects "'s it". | |
| std::vector<int> expected_token_ids = {24, 8, 66, 0}; | |
| auto constraint = | |
| FakeConstraint(expected_token_ids, /*vocabulary_size=*/2560); | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}, {0}}; | |
| // Top P sampler. | |
| proto::SamplerParameters sampler_params; | |
| sampler_params.set_type(proto::SamplerParameters::TOP_P); | |
| sampler_params.set_k(1); | |
| sampler_params.set_temperature(1.0); | |
| sampler_params.set_p(0.5); | |
| sampler_params.set_seed(1); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableSamplerParams() = sampler_params; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.SetStartTokenId(2); | |
| session_config.SetUseExternalSampler(true); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| /*prefill_tokens=*/{{2, 224}, // The first prefill. | |
| {0}}, // The expected prefill tokens that after | |
| // stop tokens are found in decoding with | |
| // sampler. That is, the last | |
| // sampled tokens at stop condition. | |
| // "How's it going?" | |
| /*decode_tokens=*/{{24}, {8}, {66}, {246}, {18}, {2295}, {2294}})); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create(tokenizer_.get(), model_resources_.get(), | |
| std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/nullptr, | |
| /*litert_env=*/nullptr)); | |
| auto session = | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("How")); | |
| absl::Status status; | |
| TaskState task_state; | |
| std::vector<std::string> texts; | |
| auto decode_config = DecodeConfig::CreateDefault(); | |
| decode_config.SetConstraint(&constraint); | |
| EXPECT_OK((*session)->RunPrefill(inputs)); | |
| ASSERT_OK_AND_ASSIGN(auto task_controller, (*session)->RunDecodeAsync( | |
| CreateStreamingTestCallback( | |
| status, task_state, texts), | |
| decode_config)); | |
| EXPECT_OK(task_controller->WaitUntilDone(absl::Seconds(10))); | |
| EXPECT_OK(status); | |
| EXPECT_EQ(task_state, TaskState::kDone); | |
| EXPECT_EQ(texts.size(), 3); | |
| EXPECT_THAT(texts, testing::ElementsAre("'", "s", " it")); | |
| } | |
// Incremental prefill: two RunPrefill calls contribute chunks of the same
// user turn, a decode closes the turn, then a second prefill+decode round
// follows. The fake executor's prefill_tokens entries list the exact token
// chunk expected for each call, including the turn-change template tokens
// prefilled right before each decode.
TEST_F(SessionAdvancedTest, RunIncrementalPrefillWithDecode) {
  const std::vector<std::vector<int>> stop_token_ids = {{2294}};
  SessionConfig session_config = SessionConfig::CreateDefault();
  session_config.SetStartTokenId(2);
  session_config.SetSamplerBackend(Backend::CPU);
  session_config.GetMutableSamplerParams() = sampler_params_;
  session_config.GetMutableStopTokenIds() = stop_token_ids;
  session_config.GetMutablePromptTemplates().mutable_user()->set_prefix(
      "User:");
  session_config.GetMutablePromptTemplates().mutable_user()->set_suffix(
      "[END]");
  session_config.GetMutablePromptTemplates().mutable_model()->set_prefix(
      "Model:");
  session_config.GetMutableLlmModelType().mutable_gemma3n();
  ASSERT_OK_AND_ASSIGN(
      auto executor,
      CreateFakeLlmExecutor(
          /*prefill_tokens=*/
          {
              {2, 423, 8, 179, 29, 207, 19, 547, 58},  // prefill chunk 1.1
              {735, 210, 466, 2294},  // prefill chunk 1.2
              {433, 2172, 1920, 432, 197, 979, 3076,
               29},  // prefill ran before decode with turn change template
              {423, 8, 179, 29, 207, 19, 547, 58, 735, 210, 466,
               2294},  // prefill chunk 2.1
              {433, 2172, 1920, 432, 197, 979, 3076,
               29},  // prefill ran before decode with turn change template
          },
          /*decode_tokens=*/
          {{1}, {2}, {3}, {2294}, {1}, {2}, {3}, {2294}}));
  ASSERT_OK_AND_ASSIGN(
      std::shared_ptr<ExecutionManager> execution_manager,
      ExecutionManager::Create(tokenizer_.get(), model_resources_.get(),
                               std::move(executor),
                               /*vision_executor_settings=*/nullptr,
                               /*audio_executor_settings=*/nullptr,
                               /*litert_env=*/nullptr));
  ASSERT_OK_AND_ASSIGN(
      auto session,
      SessionAdvanced::Create(execution_manager, tokenizer_.get(),
                              session_config, /*benchmark_info=*/std::nullopt));
  // First turn, chunk 1: partial user text.
  {
    std::vector<InputData> inputs;
    inputs.emplace_back(InputText("Hello "));
    EXPECT_OK(session->RunPrefill(inputs));
  }
  // First turn, chunk 2: remainder of the user text.
  {
    std::vector<InputData> inputs;
    inputs.emplace_back(InputText("World!"));
    EXPECT_OK(session->RunPrefill(inputs));
  }
  // Decode closes the first turn.
  {
    EXPECT_OK(session->RunDecode());
  }
  // Second turn: the full user text in one prefill.
  {
    std::vector<InputData> inputs;
    inputs.emplace_back(InputText("Hello World!"));
    EXPECT_OK(session->RunPrefill(inputs));
  }
  // Decode for the second turn.
  {
    EXPECT_OK(session->RunDecode());
  }
}
| TEST_F(SessionAdvancedTest, ProcessAndCombineContentsTextAndAudioSuccess) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetAudioModalityEnabled(true); | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "User:"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "[END]"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "Model:"); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto audio_executor_settings, | |
| CreateAudioExecutorSettings((std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestAudioModelPath)) | |
| .string(), | |
| /*max_sequence_length=*/0, Backend::CPU)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "User:Hello World!<start_of_audio>[END]Model:" | |
| /*prefill_tokens=*/{{2, 423, 8, 179, 29, 207, 19, | |
| 547, 58, 735, 210, 466, 2294, 256000, | |
| -2, -2, -2, -2, -2, -4}, | |
| {433, 2172, 1920, 432, 197, 979, 3076, 29}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/ | |
| {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*audio_embedding=*/ | |
| std::vector<float>(kExpectedAudioEmbedding.begin(), | |
| kExpectedAudioEmbedding.end()))); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto env, Environment::Create(std::vector<Environment::Option>())); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create( | |
| tokenizer_.get(), model_resources_.get(), std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/std::move(audio_executor_settings), | |
| /*litert_env=*/&env)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!<start_of_audio>")); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| TensorBuffer mel_spectrogram_data, | |
| CopyToTensorBuffer<float>( | |
| mel_spectrogram_data, | |
| {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); | |
| InputAudio input_audio(std::move(mel_spectrogram_data)); | |
| inputs.emplace_back(std::move(input_audio)); | |
| inputs.emplace_back(InputAudioEnd()); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| TEST_F(SessionAdvancedTest, ProcessAndCombineContentsTextAudioTextSuccess) { | |
| const std::vector<std::vector<int>> stop_token_ids = {{2294}}; | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetAudioModalityEnabled(true); | |
| session_config.SetStartTokenId(2); | |
| session_config.SetSamplerBackend(Backend::CPU); | |
| session_config.GetMutableSamplerParams() = sampler_params_; | |
| session_config.GetMutableStopTokenIds() = stop_token_ids; | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_prefix( | |
| "User:"); | |
| session_config.GetMutablePromptTemplates().mutable_user()->set_suffix( | |
| "[END]"); | |
| session_config.GetMutablePromptTemplates().mutable_model()->set_prefix( | |
| "Model:"); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto audio_executor_settings, | |
| CreateAudioExecutorSettings((std::filesystem::path(::testing::SrcDir()) / | |
| std::string(kTestAudioModelPath)) | |
| .string(), | |
| /*max_sequence_length=*/0, Backend::CPU)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto executor, | |
| CreateFakeLlmExecutor( | |
| // "User:Hello World!<start_of_audio>What does the audio say?" | |
| // "[END]Model:" | |
| /*prefill_tokens=*/ | |
| {{2, 423, 8, 179, 29, 207, 19, 547, 58, 735, 210, | |
| 466, 2294, 256000, -2, -2, -2, -2, -2, -4, 583, 378, | |
| 844, 166, 3, 14, 1252, 54, 58, 626, 2295}, | |
| {3995, 2172, 1920, 432, 197, 979, 3076, 29}}, | |
| // "How's it going?" | |
| /*decode_tokens=*/ | |
| {{224}, {24}, {8}, {66}, {246}, {18}, {2295}, {2294}}, | |
| /*audio_embedding=*/ | |
| std::vector<float>(kExpectedAudioEmbedding.begin(), | |
| kExpectedAudioEmbedding.end()))); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| auto env, Environment::Create(std::vector<Environment::Option>())); | |
| ASSERT_OK_AND_ASSIGN( | |
| std::shared_ptr<ExecutionManager> execution_manager, | |
| ExecutionManager::Create( | |
| tokenizer_.get(), model_resources_.get(), std::move(executor), | |
| /*vision_executor_settings=*/nullptr, | |
| /*audio_executor_settings=*/std::move(audio_executor_settings), | |
| /*litert_env=*/&env)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto session, | |
| SessionAdvanced::Create(execution_manager, tokenizer_.get(), | |
| session_config, /*benchmark_info=*/std::nullopt)); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!<start_of_audio>")); | |
| LITERT_ASSERT_OK_AND_ASSIGN( | |
| TensorBuffer mel_spectrogram_data, | |
| CopyToTensorBuffer<float>( | |
| mel_spectrogram_data, | |
| {1, kSpectrogramSequenceLength, kSpectrogramFrequencySlots})); | |
| InputAudio input_audio(std::move(mel_spectrogram_data)); | |
| inputs.emplace_back(std::move(input_audio)); | |
| inputs.emplace_back(InputAudioEnd()); | |
| inputs.emplace_back(InputText("What does the audio say?")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| } | |
| TEST_F(SessionAdvancedTest, RunTextScoringEmptyTargetTextFailure) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<absl::string_view> target_text; | |
| EXPECT_THAT(session->RunTextScoring(target_text, | |
| /*store_token_lengths=*/false), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "Target text size should be 1.")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunTextScoringMultipleTargetTextFailure) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<absl::string_view> target_text; | |
| target_text.push_back("How's it going?"); | |
| target_text.push_back("How are you?"); | |
| EXPECT_THAT( | |
| session->RunTextScoring(target_text, /*store_token_lengths=*/false), | |
| StatusIs(absl::StatusCode::kInvalidArgument, | |
| "Target text size should be 1.")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunTextScoringWithoutTokenLengthsSuccess) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| std::vector<absl::string_view> target_texts; | |
| target_texts.push_back("How's it going?"); | |
| const auto responses = session->RunTextScoring(target_texts, | |
| /*store_token_lengths=*/false); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate with score 0.0f. | |
| EXPECT_EQ(responses->GetScores().size(), 1); | |
| EXPECT_EQ(responses->GetScores()[0], 0.0f); | |
| EXPECT_FALSE(responses->GetTokenLengths().has_value()); | |
| } | |
| TEST_F(SessionAdvancedTest, RunTextScoringWithTokenLengthsSuccess) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<InputData> inputs; | |
| inputs.emplace_back(InputText("Hello World!")); | |
| EXPECT_OK(session->RunPrefill(inputs)); | |
| std::vector<absl::string_view> target_texts; | |
| target_texts.push_back("How's it going?"); | |
| const auto responses = session->RunTextScoring(target_texts, | |
| /*store_token_lengths=*/true); | |
| EXPECT_OK(responses); | |
| // Expect a single output candidate with score 0.0f and token length 7. | |
| EXPECT_EQ(responses->GetScores().size(), 1); | |
| EXPECT_EQ(responses->GetScores()[0], 0.0f); | |
| EXPECT_TRUE(responses->GetTokenLengths().has_value()); | |
| EXPECT_EQ(responses->GetTokenLengths()->size(), 1); | |
| EXPECT_EQ((*responses->GetTokenLengths())[0], 7); | |
| } | |
| TEST_F(SessionAdvancedTest, RunTextScoringAsyncEmptyTargetTextFailure) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<absl::string_view> target_text; | |
| auto controller = session->RunTextScoringAsync( | |
| target_text, [](absl::StatusOr<Responses> r) {}, | |
| /*store_token_lengths=*/false); | |
| EXPECT_THAT(controller.status(), StatusIs(absl::StatusCode::kInvalidArgument, | |
| "Target text size should be 1.")); | |
| } | |
| TEST_F(SessionAdvancedTest, RunTextScoringAsyncMultipleTargetTextFailure) { | |
| ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession()); | |
| std::vector<absl::string_view> target_text; | |
| target_text.push_back("How's it going?"); | |
| target_text.push_back("How are you?"); | |
| auto controller = session->RunTextScoringAsync( | |
| target_text, [](absl::StatusOr<Responses> r) {}, | |
| /*store_token_lengths=*/false); | |
| EXPECT_THAT(controller.status(), StatusIs(absl::StatusCode::kInvalidArgument, | |
| "Target text size should be 1.")); | |
| } | |
// Async text scoring without token lengths: the callback captures any error
// status and keeps only the terminal Responses for inspection on the test
// thread after WaitUntilDone.
TEST_F(SessionAdvancedTest, RunTextScoringAsyncWithoutTokenLengthsSuccess) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  std::vector<absl::string_view> target_texts;
  target_texts.push_back("How's it going?");
  absl::Status status;
  std::optional<Responses> responses;
  ASSERT_OK_AND_ASSIGN(auto controller,
                       session->RunTextScoringAsync(
                           target_texts,
                           [&](absl::StatusOr<Responses> r) {
                             if (!r.ok()) {
                               status = r.status();
                               return;
                             }
                             // Keep only the terminal response.
                             if (IsTaskEndState(r->GetTaskState())) {
                               responses.emplace(*std::move(r));
                             }
                           },
                           /*store_token_lengths=*/false));
  EXPECT_OK(controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  ASSERT_TRUE(responses.has_value());
  // Expect a single output candidate with score 0.0f.
  EXPECT_EQ(responses->GetScores().size(), 1);
  EXPECT_EQ(responses->GetScores()[0], 0.0f);
  EXPECT_FALSE(responses->GetTokenLengths().has_value());
}
// Async text scoring with token lengths requested: as above, but the
// terminal Responses must also carry the 7-token length of the target.
TEST_F(SessionAdvancedTest, RunTextScoringAsyncWithTokenLengthsSuccess) {
  ASSERT_OK_AND_ASSIGN(auto session, CreateTestSession());
  std::vector<InputData> inputs;
  inputs.emplace_back(InputText("Hello World!"));
  EXPECT_OK(session->RunPrefill(inputs));
  std::vector<absl::string_view> target_texts;
  target_texts.push_back("How's it going?");
  absl::Status status;
  std::optional<Responses> responses;
  ASSERT_OK_AND_ASSIGN(auto controller,
                       session->RunTextScoringAsync(
                           target_texts,
                           [&](absl::StatusOr<Responses> r) {
                             if (!r.ok()) {
                               status = r.status();
                               return;
                             }
                             // Keep only the terminal response.
                             if (IsTaskEndState(r->GetTaskState())) {
                               responses.emplace(*std::move(r));
                             }
                           },
                           /*store_token_lengths=*/true));
  EXPECT_OK(controller->WaitUntilDone(absl::Seconds(10)));
  EXPECT_OK(status);
  ASSERT_TRUE(responses.has_value());
  // Expect a single output candidate with score 0.0f and token length 7.
  EXPECT_EQ(responses->GetScores().size(), 1);
  EXPECT_EQ(responses->GetScores()[0], 0.0f);
  EXPECT_TRUE(responses->GetTokenLengths().has_value());
  EXPECT_EQ(responses->GetTokenLengths()->size(), 1);
  EXPECT_EQ((*responses->GetTokenLengths())[0], 7);
}
| } // namespace | |
| } // namespace litert::lm | |