Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
| using ::testing::AllOf; | |
| using ::testing::ElementsAre; | |
| using ::testing::HasSubstr; | |
| using ::testing::Not; | |
| using ::testing::Return; | |
| using ::testing::VariantWith; | |
| absl::string_view kTestLlmPath = | |
| "litert_lm/runtime/testdata/test_lm.litertlm"; | |
| constexpr char kTestTokenizerPath[] = | |
| "litert_lm/runtime/components/testdata/gemma3_sentencepiece.model"; | |
| constexpr char kGemma3ToolsMultiPrefillTemplatePath[] = | |
| "litert_lm/runtime/components/testdata/" | |
| "google-gemma-3n-e2b-it-tools-multi-prefill.jinja"; | |
| constexpr char kGemma3TemplatePath[] = | |
| "litert_lm/runtime/components/testdata/google-gemma-3-1b-it.jinja"; | |
| constexpr absl::string_view kTestJinjaPromptTemplate = R"jinja( | |
| {%- for message in messages -%} | |
| {{- '<start_of_turn>' + message.role + '\n' -}} | |
| {%- if message.content is string -%} | |
| {{- message.content + '<end_of_turn>\n' -}} | |
| {%- else -%} | |
| {{- message.content[0].text + '<end_of_turn>\n' -}} | |
| {%- endif -%} | |
| {%- endfor -%} | |
| )jinja"; | |
| std::string GetTestdataPath(absl::string_view file_path) { | |
| return absl::StrCat(::testing::SrcDir(), "/", file_path); | |
| } | |
| std::string ReadFile(absl::string_view path) { | |
| std::ifstream ifstr(std::string(path), std::ios::binary); | |
| std::stringstream contents; | |
| contents << ifstr.rdbuf(); | |
| return contents.str(); | |
| } | |
| class MockSession : public Engine::Session { | |
| public: | |
| MOCK_METHOD(absl::StatusOr<Responses>, GenerateContent, | |
| (const std::vector<InputData>& contents), (override)); | |
| MOCK_METHOD( | |
| absl::Status, GenerateContentStream, | |
| (const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback), | |
| (override)); | |
| MOCK_METHOD( | |
| absl::Status, GenerateContentStream, | |
| (const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config), | |
| (override)); | |
| MOCK_METHOD(absl::StatusOr<Responses>, RunTextScoring, | |
| (const std::vector<absl::string_view>& target_text, | |
| bool store_token_lengths), | |
| (override)); | |
| MOCK_METHOD(absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>>, | |
| RunTextScoringAsync, | |
| (const std::vector<absl::string_view>& target_text, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| bool store_token_lengths), | |
| (override)); | |
| MOCK_METHOD(absl::Status, RunPrefill, | |
| (const std::vector<InputData>& contents), (override)); | |
| MOCK_METHOD( | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>>, | |
| RunPrefillAsync, | |
| (const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback), | |
| (override)); | |
| MOCK_METHOD(absl::StatusOr<Responses>, RunDecode, (), (override)); | |
| MOCK_METHOD(absl::StatusOr<Responses>, RunDecode, | |
| (const DecodeConfig& decode_config), (override)); | |
| MOCK_METHOD( | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>>, | |
| RunDecodeAsync, | |
| (absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback), | |
| (override)); | |
| MOCK_METHOD( | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>>, | |
| RunDecodeAsync, | |
| (absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config), | |
| (override)); | |
| MOCK_METHOD(absl::StatusOr<std::unique_ptr<Session>>, Clone, (), (override)); | |
| MOCK_METHOD(absl::StatusOr<std::unique_ptr<Session>>, CloneAsync, | |
| (absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback), | |
| (override)); | |
| MOCK_METHOD(absl::StatusOr<BenchmarkInfo>, GetBenchmarkInfo, (), (override)); | |
| MOCK_METHOD(absl::StatusOr<BenchmarkInfo*>, GetMutableBenchmarkInfo, (), | |
| (override)); | |
| MOCK_METHOD(void, CancelProcess, (), (override)); | |
| MOCK_METHOD(absl::Status, WaitUntilDone, (), (override)); | |
| MOCK_METHOD(absl::Status, SaveCheckpoint, (absl::string_view label), | |
| (override)); | |
| MOCK_METHOD(absl::Status, RewindToCheckpoint, (absl::string_view label), | |
| (override)); | |
| MOCK_METHOD(const SessionConfig&, GetSessionConfig, (), (const, override)); | |
| }; | |
| class MockEngine : public Engine { | |
| public: | |
| MOCK_METHOD(const EngineSettings&, GetEngineSettings, (), (const, override)); | |
| MOCK_METHOD(const Tokenizer&, GetTokenizer, (), (const, override)); | |
| MOCK_METHOD(absl::StatusOr<AudioExecutorProperties>, | |
| GetAudioExecutorProperties, (), (const, override)); | |
| MOCK_METHOD(absl::StatusOr<VisionExecutorProperties>, | |
| GetVisionExecutorProperties, (), (const, override)); | |
| MOCK_METHOD(absl::StatusOr<std::unique_ptr<Session>>, CreateSession, | |
| (const SessionConfig& session_config), (override)); | |
| MOCK_METHOD(absl::Status, WaitUntilDone, (absl::Duration timeout), | |
| (override)); | |
| }; | |
| class MockTaskController : public Engine::Session::TaskController { | |
| public: | |
| MockTaskController() = default; | |
| ~MockTaskController() override = default; | |
| MOCK_METHOD(absl::Status, Cancel, (), (override)); | |
| }; | |
| absl::AnyInvocable<void(absl::StatusOr<Message>)> CreateTestMessageCallback( | |
| Message& expected_message, absl::Notification& done) { | |
| return [&expected_message, &done](absl::StatusOr<Message> message) mutable { | |
| // If the message is not ok, fail the test. | |
| if (!message.ok()) { | |
| FAIL() << "Message user_callback failed: " << message.status(); | |
| return; | |
| } | |
| // If the message is null, the last callback is received. | |
| if (auto json_message = std::get_if<JsonMessage>(&message.value()); | |
| json_message->is_null()) { | |
| JsonMessage& expected_json_message = | |
| std::get<JsonMessage>(expected_message); | |
| ASSERT_TRUE(expected_json_message["content"][0]["text"].is_string()); | |
| std::string expected_string = expected_json_message["content"][0]["text"]; | |
| // The expected string should be empty after the last callback. | |
| EXPECT_TRUE(expected_string.empty()); | |
| done.Notify(); | |
| return; | |
| } | |
| // Otherwise, this is a partial response. | |
| if (auto json_message = std::get_if<JsonMessage>(&message.value())) { | |
| JsonMessage& expected_json_message = | |
| std::get<JsonMessage>(expected_message); | |
| // Compare the message text content by prefix, and update the expected | |
| // message to the remaining text for the next user_callback. | |
| ASSERT_TRUE(expected_json_message["content"][0]["text"].is_string()); | |
| ASSERT_TRUE((*json_message)["content"][0]["text"].is_string()); | |
| std::string expected_string = expected_json_message["content"][0]["text"]; | |
| std::string actual_string = (*json_message)["content"][0]["text"]; | |
| EXPECT_TRUE(absl::StartsWith(expected_string, actual_string)) | |
| << "Expected: " << expected_string << "\nActual: " << actual_string; | |
| expected_json_message["content"][0]["text"] = | |
| expected_string.substr(actual_string.size()); | |
| } | |
| }; | |
| } | |
| absl::AnyInvocable<void(absl::StatusOr<Message>)> | |
| CreateTestMultiMessageCallback(const std::vector<Message>& expected_messages, | |
| absl::Notification& done) { | |
| return [&expected_messages, &done, | |
| current_index = 0](absl::StatusOr<Message> message) mutable { | |
| ASSERT_OK(message); | |
| ASSERT_TRUE(std::holds_alternative<JsonMessage>(message.value())); | |
| auto json_message = std::get<JsonMessage>(message.value()); | |
| // If the message is null, the message stream is complete. | |
| if (json_message.is_null()) { | |
| EXPECT_TRUE(current_index == expected_messages.size()) | |
| << "Expected " << expected_messages.size() | |
| << " messages but only got " << current_index; | |
| done.Notify(); | |
| return; | |
| } | |
| ASSERT_LT(current_index, expected_messages.size()) | |
| << "Received more messages than expected. Expected size: " | |
| << expected_messages.size(); | |
| EXPECT_THAT(*message, testing::Eq(expected_messages[current_index])); | |
| ++current_index; | |
| }; | |
| } | |
| TEST(ConversationConfigTest, CreateDefault) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN(auto config, ConversationConfig::CreateDefault(*engine)); | |
| EXPECT_OK(Conversation::Create(*engine, config)); | |
| } | |
| TEST(ConversationConfigTest, CreateDefaultWithOverwritePromptTemplate) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN(auto config, ConversationConfig::Builder() | |
| .SetOverwritePromptTemplate( | |
| PromptTemplate("Hello world!")) | |
| .Build(*engine)); | |
| EXPECT_EQ(config.GetPromptTemplate().GetTemplateSource(), "Hello world!"); | |
| EXPECT_TRUE( | |
| config.GetSessionConfig().GetPromptTemplates().user().prefix().empty()); | |
| } | |
| TEST(ConversationConfigTest, CreateWithBuilder) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| auto session_config = SessionConfig::CreateDefault(); | |
| session_config.GetMutableLlmModelType().mutable_gemma3n(); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config) | |
| .SetPreface(JsonPreface{ | |
| .messages = {{{"role", "system"}, | |
| {"content", "You are a helpful assistant."}}}}) | |
| .Build(*engine)); | |
| EXPECT_TRUE(std::holds_alternative<JsonPreface>(config.GetPreface())); | |
| EXPECT_EQ( | |
| std::get<JsonPreface>(config.GetPreface()).messages, | |
| nlohmann::ordered_json( | |
| {{{"role", "system"}, {"content", "You are a helpful assistant."}}})); | |
| EXPECT_EQ(config.GetSessionConfig().GetLlmModelType().model_type_case(), | |
| proto::LlmModelType::kGemma3N); | |
| EXPECT_TRUE( | |
| config.GetSessionConfig().GetPromptTemplates().user().prefix().empty()); | |
| EXPECT_OK(Conversation::Create(*engine, config)); | |
| } | |
| TEST(ConversationConfigTest, FilterChannelContentFromKvCache) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN(auto config, | |
| ConversationConfig::Builder() | |
| .SetFilterChannelContentFromKvCache(true) | |
| .Build(*engine)); | |
| EXPECT_TRUE(config.filter_channel_content_from_kv_cache()); | |
| } | |
| TEST(ConversationConfigTest, OverwritePromptTemplate) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetOverwritePromptTemplate(PromptTemplate("overwrite template")) | |
| .Build(*engine)); | |
| EXPECT_EQ(config.GetPromptTemplate().GetTemplateSource(), | |
| "overwrite template"); | |
| } | |
| struct ConversationTestParams { | |
| bool enable_constrained_decoding; | |
| bool prefill_preface_on_init; | |
| }; | |
| class ConversationTest : public testing::TestWithParam<ConversationTestParams> { | |
| public: | |
| static std::vector<ConversationTestParams> GetTestParams() { | |
| std::vector<ConversationTestParams> params; | |
| for (bool enable_constrained_decoding : {true, false}) { | |
| for (bool prefill_preface_on_init : {true, false}) { | |
| params.push_back( | |
| {enable_constrained_decoding, prefill_preface_on_init}); | |
| } | |
| } | |
| return params; | |
| } | |
| protected: | |
| void SetUp() override { | |
| ASSERT_OK_AND_ASSIGN( | |
| tokenizer_, | |
| SentencePieceTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / kTestTokenizerPath) | |
| .string())); | |
| model_assets_ = ModelAssets::Create(GetTestdataPath(kTestLlmPath)); | |
| ASSERT_OK(model_assets_); | |
| engine_settings_ = | |
| EngineSettings::CreateDefault(*model_assets_, Backend::CPU); | |
| ASSERT_OK(engine_settings_); | |
| session_config_ = SessionConfig::CreateDefault(); | |
| session_config_.SetStartTokenId(0); | |
| session_config_.GetMutableStopTokenIds().push_back({1}); | |
| *session_config_.GetMutableLlmModelType().mutable_gemma3() = {}; | |
| } | |
| std::unique_ptr<MockSession> CreateMockSession() { | |
| auto mock_session = std::make_unique<MockSession>(); | |
| EXPECT_CALL(*mock_session, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config_)); | |
| return mock_session; | |
| } | |
| std::unique_ptr<MockEngine> CreateMockEngine( | |
| std::unique_ptr<MockSession> mock_session) { | |
| auto mock_engine = std::make_unique<MockEngine>(); | |
| EXPECT_CALL(*mock_engine, GetEngineSettings()) | |
| .WillRepeatedly(testing::ReturnRef(*engine_settings_)); | |
| EXPECT_CALL(*mock_engine, CreateSession(testing::_)) | |
| .WillOnce(testing::Return(std::move(mock_session))); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer_)); | |
| return mock_engine; | |
| } | |
| std::unique_ptr<Tokenizer> tokenizer_; | |
| absl::StatusOr<ModelAssets> model_assets_; | |
| absl::StatusOr<EngineSettings> engine_settings_; | |
| SessionConfig session_config_ = SessionConfig::CreateDefault(); | |
| bool enable_constrained_decoding_ = GetParam().enable_constrained_decoding; | |
| bool prefill_preface_on_init_ = GetParam().prefill_preface_on_init; | |
| }; | |
| TEST_P(ConversationTest, SendMessage) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| EXPECT_THAT(conversation->GetHistory(), testing::IsEmpty()); | |
| JsonMessage user_message = {{"role", "user"}, {"content", "Hello world!"}}; | |
| ASSERT_OK_AND_ASSIGN(const Message message, | |
| conversation->SendMessage(user_message)); | |
| // The expected message is just some gibberish text, because the test LLM has | |
| // random weights. | |
| JsonMessage expected_message = { | |
| {"role", "assistant"}, | |
| {"content", | |
| {{{"type", "text"}, {"text", "TarefaByte دارایेत्र investigaciónప్రదేశ"}}}}}; | |
| const JsonMessage& json_message = std::get<JsonMessage>(message); | |
| EXPECT_EQ(json_message, expected_message); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, expected_message)); | |
| } | |
| TEST_P(ConversationTest, SendMessageGemma3Template) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(20); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| std::string gemma3_prompt_template = | |
| ReadFile(GetTestdataPath(kGemma3TemplatePath)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .SetOverwritePromptTemplate(PromptTemplate(gemma3_prompt_template)) | |
| .Build(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| EXPECT_THAT(conversation->GetHistory(), testing::IsEmpty()); | |
| JsonMessage user_message = {{"role", "user"}, {"content", "Hello world!"}}; | |
| EXPECT_OK(conversation->SendMessage(user_message)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessage) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN(const Message response, | |
| conversation->SendMessage(user_message)); | |
| JsonMessage assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })"); | |
| EXPECT_EQ(std::get<JsonMessage>(response), assistant_message); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, assistant_message)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessageWithExtraContext) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation and overwrite prompt template. | |
| absl::string_view prompt_template = R"jinja( | |
| {%- if enable_thinking -%} | |
| <start_of_turn>system | |
| Thinking enabled.<end_of_turn> | |
| {% else %} | |
| <start_of_turn>system | |
| Thinking disabled.<end_of_turn> | |
| {%- endif -%} | |
| {%- for message in messages -%} | |
| {{- '<start_of_turn>' + message.role + '\n' -}} | |
| {%- if message.content is string -%} | |
| {{- message.content + '<end_of_turn>\n' -}} | |
| {%- else -%} | |
| {{- message.content[0].text + '<end_of_turn>\n' -}} | |
| {%- endif -%} | |
| {%- endfor -%} | |
| )jinja"; | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(prompt_template)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| OptionalArgs optional_args; | |
| optional_args.extra_context = absl::flat_hash_map<std::string, std::string>{ | |
| {"enable_thinking", "true"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>system\nThinking enabled.<end_of_turn>\n" | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN( | |
| const Message response, | |
| conversation->SendMessage(user_message, std::move(optional_args))); | |
| JsonMessage assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })"); | |
| EXPECT_EQ(std::get<JsonMessage>(response), assistant_message); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, assistant_message)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessageWithExtraContextOverwritingPreface) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation and overwrite prompt template. | |
| absl::string_view prompt_template = R"jinja( | |
| {%- if key1 -%} | |
| Key1: {{ key1 + "\n"}} | |
| {%- endif -%} | |
| {%- if key2 -%} | |
| Key2: {{ key2 + "\n"}} | |
| {%- endif -%} | |
| {%- if key3 -%} | |
| Key3: {{ key3 + "\n"}} | |
| {%- endif -%} | |
| {%- for message in messages -%} | |
| {{- '<start_of_turn>' + message.role + '\n' -}} | |
| {%- if message.content is string -%} | |
| {{- message.content + '<end_of_turn>\n' -}} | |
| {%- else -%} | |
| {{- message.content[0].text + '<end_of_turn>\n' -}} | |
| {%- endif -%} | |
| {%- endfor -%} | |
| )jinja"; | |
| JsonPreface preface; | |
| // This extra context will be set at the Conversation level. | |
| preface.extra_context = {{"key1", "val1"}, {"key2", "val2"}}; | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetPreface(preface) | |
| .SetOverwritePromptTemplate(PromptTemplate(prompt_template)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send a single message with extra context that overwrites key1 and | |
| // adds key3. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| OptionalArgs optional_args; | |
| optional_args.extra_context = | |
| nlohmann::ordered_json{{"key1", "val1_new"}, {"key3", "val3"}}; | |
| // key1 should be overwritten to val1_new. | |
| // key2 should remain val2. | |
| // key3 should be added as val3. | |
| absl::string_view expected_input_text = | |
| "Key1: val1_new\n" | |
| "Key2: val2\n" | |
| "Key3: val3\n" | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN( | |
| const Message response, | |
| conversation->SendMessage(user_message, std::move(optional_args))); | |
| JsonMessage assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })"); | |
| EXPECT_EQ(std::get<JsonMessage>(response), assistant_message); | |
| } | |
| TEST_P(ConversationTest, SendMultipleMessages) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send two consecutive messages. | |
| JsonMessage user_messages = nlohmann::ordered_json::parse(R"json( | |
| [ | |
| { | |
| "role": "user", | |
| "content": "Hello world!" | |
| }, | |
| { | |
| "role": "user", | |
| "content": "How are you?" | |
| } | |
| ] | |
| )json"); | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "Hello world!<end_of_turn>\n" | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN(const Message response, | |
| conversation->SendMessage(user_messages)); | |
| JsonMessage assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })"); | |
| EXPECT_EQ(std::get<JsonMessage>(response), assistant_message); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_messages[0], user_messages[1], | |
| assistant_message)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessageWithChannel) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .SetChannels({litert::lm::Channel{ | |
| .channel_name = "thought", | |
| .start = "<|channel>thought\n", | |
| .end = "<channel|>", | |
| }}) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce(testing::Return( | |
| Responses(TaskState::kProcessing, | |
| {"<|channel>thought\nhmm<channel|>I am good."}))); | |
| // Send the message. | |
| ASSERT_OK_AND_ASSIGN(const Message response, | |
| conversation->SendMessage(user_message)); | |
| JsonMessage assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ], | |
| "channels": { | |
| "thought": "hmm" | |
| } | |
| })"); | |
| EXPECT_THAT(std::get<JsonMessage>(response), testing::Eq(assistant_message)); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, assistant_message)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessageWithChannelQwenThink) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .SetChannels({litert::lm::Channel{ | |
| .channel_name = "thought", | |
| .start = "<think>\n", | |
| .end = "\n</think>", | |
| }}) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce(testing::Return(Responses( | |
| TaskState::kProcessing, {"<think>\nhmm\n</think>I am good."}))); | |
| // Send the message. | |
| ASSERT_OK_AND_ASSIGN(const Message response, | |
| conversation->SendMessage(user_message)); | |
| JsonMessage assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ], | |
| "channels": { | |
| "thought": "hmm" | |
| } | |
| })"); | |
| EXPECT_THAT(std::get<JsonMessage>(response), testing::Eq(assistant_message)); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, assistant_message)); | |
| } | |
| TEST_P(ConversationTest, SendMessageWithChannelContentFiltering) { | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Helper to get the raw text string from `InputText`. | |
| auto get_text = [](const InputText& it) -> std::string { | |
| auto status_or_view = it.GetRawTextString(); | |
| if (!status_or_view.ok()) return ""; | |
| return std::string(*status_or_view); | |
| }; | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .SetFilterChannelContentFromKvCache(true) | |
| .SetChannels({litert::lm::Channel{ | |
| .channel_name = "thought", | |
| .start = "<|channel>thought\n", | |
| .end = "<channel|>", | |
| }}) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Expect prefill of first user message. | |
| EXPECT_CALL(*mock_session_ptr, RunPrefill(testing::_)) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect checkpoint to be saved after the first user message is prefilled. | |
| EXPECT_CALL(*mock_session_ptr, SaveCheckpoint("channel_content_checkpoint")) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect decode after first user message. Return response with channel | |
| // content. | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| Return(Responses(TaskState::kProcessing, | |
| {"<|channel>thought\nhmm<channel|>I am good."}))); | |
| // Send the first user message. | |
| JsonMessage user_message_1 = {{"role", "user"}, {"content", "How are you?"}}; | |
| ASSERT_OK(conversation->SendMessage(user_message_1)); | |
| // Expect rewind to checkpoint after second user message is sent. | |
| EXPECT_CALL(*mock_session_ptr, | |
| RewindToCheckpoint("channel_content_checkpoint")) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect previous assistant message and second user message to be prefilled | |
| // when the second user message is sent. The assistant message should not | |
| // have channel content. | |
| auto assistant_message_matcher = | |
| AllOf(HasSubstr("I am good."), Not(HasSubstr("hmm"))); | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefill(ElementsAre( | |
| VariantWith<InputText>(ResultOf(get_text, assistant_message_matcher)), | |
| VariantWith<InputText>( | |
| ResultOf(get_text, HasSubstr("That's great.")))))) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect a new checkpoint to be saved after the second user message is | |
| // prefilled. | |
| EXPECT_CALL(*mock_session_ptr, SaveCheckpoint("channel_content_checkpoint")) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect decode after second user message. | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce(Return(Responses(TaskState::kProcessing, {"Thank you."}))); | |
| // Send the second user message. | |
| JsonMessage user_message_2 = {{"role", "user"}, {"content", "That's great."}}; | |
| ASSERT_OK(conversation->SendMessage(user_message_2)); | |
| } | |
| TEST_P(ConversationTest, SendMultipleMessagesWithHistory) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // The first user message. | |
| JsonMessage user_message_1 = nlohmann::ordered_json::parse(R"json( | |
| { | |
| "role": "user", | |
| "content": "How are you?" | |
| } | |
| )json"); | |
| EXPECT_CALL(*mock_session_ptr, RunPrefill(testing::_)) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // The first assistant response. | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| // Send the first user message to fill the history. | |
| ASSERT_OK(conversation->SendMessage(user_message_1)); | |
| ASSERT_THAT(conversation->GetHistory().size(), testing::Eq(2)); | |
| // We will send two consecutive messages when the history is not empty. | |
| JsonMessage user_messages = nlohmann::ordered_json::parse(R"json( | |
| [ | |
| { | |
| "role": "user", | |
| "content": "foo" | |
| }, | |
| { | |
| "role": "user", | |
| "content": "bar" | |
| } | |
| ] | |
| )json"); | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "foo<end_of_turn>\n" | |
| "<start_of_turn>user\n" | |
| "bar<end_of_turn>\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_input_text))))) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // The second assistant response. | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce(testing::Return(Responses(TaskState::kProcessing, {"baz"}))); | |
| // Send the user messages. | |
| ASSERT_OK(conversation->SendMessage(user_messages)); | |
| // Check the history. | |
| JsonMessage assistant_message_1 = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })"); | |
| JsonMessage assistant_message_2 = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "baz" | |
| } | |
| ] | |
| })"); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message_1, assistant_message_1, | |
| user_messages[0], user_messages[1], | |
| assistant_message_2)); | |
| } | |
| TEST_P(ConversationTest, RunTextScoring) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Test sync scoring. | |
| auto cloned_session_sync = std::make_unique<MockSession>(); | |
| EXPECT_CALL(*cloned_session_sync, | |
| RunTextScoring(testing::ElementsAre("I am good."), true)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| EXPECT_CALL(*mock_session_ptr, Clone()) | |
| .WillOnce(testing::Return(std::move(cloned_session_sync))); | |
| ASSERT_OK_AND_ASSIGN(const Responses response, | |
| conversation->RunTextScoring({"I am good."})); | |
| EXPECT_EQ(response.GetTexts()[0], "I am good."); | |
| // Test async scoring. | |
| auto cloned_session_async = std::make_unique<MockSession>(); | |
| EXPECT_CALL( | |
| *cloned_session_async, | |
| RunTextScoringAsync(testing::ElementsAre("I am good."), testing::_, true)) | |
| .WillOnce([](const std::vector<absl::string_view>& target_text, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| bool store_token_lengths) { | |
| callback(Responses(TaskState::kProcessing, {"I am good."})); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, CloneAsync(testing::_)) | |
| .WillOnce(testing::Return(std::move(cloned_session_async))); | |
| absl::Notification done; | |
| std::string response_text; | |
| EXPECT_OK(conversation->RunTextScoringAsync( | |
| {"I am good."}, [&](absl::StatusOr<Responses> responses) { | |
| ASSERT_OK(responses); | |
| response_text = responses->GetTexts()[0]; | |
| done.Notify(); | |
| })); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_EQ(response_text, "I am good."); | |
| } | |
| TEST_P(ConversationTest, SendMessageAsync) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| JsonMessage user_message = {{"role", "user"}, {"content", "Hello world!"}}; | |
| // The expected message is just some gibberish text, because the test LLM has | |
| // random weights. | |
| Message expected_message = | |
| JsonMessage({{"role", "assistant"}, | |
| {"content", | |
| {{{"type", "text"}, | |
| {"text", "TarefaByte دارایेत्र investigaciónప్రదేశ"}}}}}); | |
| Message expected_message_for_confirm = expected_message; | |
| absl::Notification done; | |
| EXPECT_OK(conversation->SendMessageAsync( | |
| user_message, CreateTestMessageCallback(expected_message, done))); | |
| // Wait for the async message to be processed. | |
| EXPECT_OK(engine->WaitUntilDone(absl::Seconds(100))); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, expected_message_for_confirm)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessageAsync) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_input_text))), | |
| testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, {"I am good."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| Message assistant_message = JsonMessage(nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })")); | |
| Message assistant_message_for_confirm = assistant_message; | |
| absl::Notification done; | |
| auto message_callback = CreateTestMessageCallback(assistant_message, done); | |
| EXPECT_OK(conversation->SendMessageAsync(user_message, | |
| std::move(message_callback))); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_THAT( | |
| conversation->GetHistory(), | |
| testing::ElementsAre(user_message, assistant_message_for_confirm)); | |
| } | |
| TEST_P(ConversationTest, SendMessageAsyncWithChannelContent) { | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| std::vector<Channel> custom_channels = {{"thought", "<think>", "</think>"}}; | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetChannels(custom_channels) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_input_text))), | |
| testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, | |
| {"Hello <think>hmm</think> World!"})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| absl::Notification done; | |
| std::vector<Message> expected_messages = { | |
| JsonMessage{{"role", "assistant"}, | |
| {"content", {{{"type", "text"}, {"text", "Hello "}}}}}, | |
| JsonMessage{{"role", "assistant"}, {"channels", {{"thought", "hmm"}}}}, | |
| JsonMessage{{"role", "assistant"}, | |
| {"content", {{{"type", "text"}, {"text", " World!"}}}}}, | |
| }; | |
| auto message_callback = | |
| CreateTestMultiMessageCallback(expected_messages, done); | |
| EXPECT_OK(conversation->SendMessageAsync(user_message, | |
| std::move(message_callback))); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| // Verify the final message in history. | |
| JsonMessage expected_assistant_message = nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "Hello World!" | |
| } | |
| ], | |
| "channels": { | |
| "thought": "hmm" | |
| } | |
| })"); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_message, expected_assistant_message)); | |
| } | |
| TEST_P(ConversationTest, SendSingleMessageAsyncWithExtraContext) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation and overwrite prompt template. | |
| absl::string_view prompt_template = R"jinja( | |
| {%- if enable_thinking -%} | |
| <start_of_turn>system | |
| Thinking enabled.<end_of_turn> | |
| {% else %} | |
| <start_of_turn>system | |
| Thinking disabled.<end_of_turn> | |
| {%- endif -%} | |
| {%- for message in messages -%} | |
| {{- '<start_of_turn>' + message.role + '\n' -}} | |
| {%- if message.content is string -%} | |
| {{- message.content + '<end_of_turn>\n' -}} | |
| {%- else -%} | |
| {{- message.content[0].text + '<end_of_turn>\n' -}} | |
| {%- endif -%} | |
| {%- endfor -%} | |
| )jinja"; | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(prompt_template)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| OptionalArgs optional_args; | |
| optional_args.extra_context = absl::flat_hash_map<std::string, std::string>{ | |
| {"enable_thinking", "true"}}; | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>system\nThinking enabled.<end_of_turn>\n" | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_input_text))), | |
| testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback( | |
| Responses(TaskState::kProcessing, {"I am good async."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| Message assistant_message = JsonMessage(nlohmann::ordered_json::parse(R"({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good async." | |
| } | |
| ] | |
| })")); | |
| Message assistant_message_for_confirm = assistant_message; | |
| absl::Notification done; | |
| auto message_callback = CreateTestMessageCallback(assistant_message, done); | |
| EXPECT_OK(conversation->SendMessageAsync( | |
| user_message, std::move(message_callback), std::move(optional_args))); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_THAT( | |
| conversation->GetHistory(), | |
| testing::ElementsAre(user_message, assistant_message_for_confirm)); | |
| } | |
| TEST_P(ConversationTest, SendMultipleMessagesAsync) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send two consecutive messages. | |
| JsonMessage user_messages = nlohmann::ordered_json::parse(R"json( | |
| [ | |
| { | |
| "role": "user", | |
| "content": "Hello world!" | |
| }, | |
| { | |
| "role": "user", | |
| "content": "How are you?" | |
| } | |
| ] | |
| )json"); | |
| absl::string_view expected_input_text = | |
| "<start_of_turn>user\n" | |
| "Hello world!<end_of_turn>\n" | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_input_text))), | |
| testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, {"I am good."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| Message assistant_message = JsonMessage(nlohmann::ordered_json::parse(R"json({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })json")); | |
| Message assistant_message_for_confirm = assistant_message; | |
| absl::Notification done; | |
| auto message_callback = CreateTestMessageCallback(assistant_message, done); | |
| EXPECT_OK(conversation->SendMessageAsync(user_messages, | |
| std::move(message_callback))); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_THAT(conversation->GetHistory(), | |
| testing::ElementsAre(user_messages[0], user_messages[1], | |
| assistant_message_for_confirm)); | |
| } | |
| TEST_P(ConversationTest, SendMessageAsyncWithChannelContentFiltering) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation with channel content filtering enabled. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetFilterChannelContentFromKvCache(true) | |
| .SetChannels({litert::lm::Channel{ | |
| .channel_name = "thought", | |
| .start = "<|channel>thought\n", | |
| .end = "<channel|>", | |
| }}) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Helper to get the raw text string from `InputText`. | |
| auto get_text = [](const InputText& it) -> std::string { | |
| auto status_or_view = it.GetRawTextString(); | |
| if (!status_or_view.ok()) return ""; | |
| return std::string(*status_or_view); | |
| }; | |
| // Expect prefill of first user message. | |
| EXPECT_CALL(*mock_session_ptr, RunPrefillAsync(testing::_, testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| // Expect checkpoint to be saved. | |
| EXPECT_CALL(*mock_session_ptr, SaveCheckpoint("channel_content_checkpoint")) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect decode after first user message. | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback( | |
| Responses(TaskState::kProcessing, | |
| {"<|channel>thought\nhmm<channel|>I am good."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| // Prepare the first user message. | |
| JsonMessage user_message_1 = {{"role", "user"}, {"content", "How are you?"}}; | |
| Message assistant_message_1 = | |
| JsonMessage(nlohmann::ordered_json::parse(R"json({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })json")); | |
| absl::Notification done_1; | |
| auto message_callback_1 = [&done_1](absl::StatusOr<Message> message) { | |
| if (!message.ok()) { | |
| done_1.Notify(); | |
| return; | |
| } | |
| if (auto* json_msg = std::get_if<JsonMessage>(&message.value())) { | |
| if (json_msg->is_null()) { | |
| done_1.Notify(); | |
| } | |
| } | |
| }; | |
| // Send the first user message. | |
| EXPECT_OK(conversation->SendMessageAsync(user_message_1, | |
| std::move(message_callback_1))); | |
| ASSERT_TRUE(done_1.WaitForNotificationWithTimeout(absl::Seconds(10))); | |
| // Prepare the second user message. | |
| JsonMessage user_message_2 = {{"role", "user"}, {"content", "That's great."}}; | |
| absl::Notification done_2; | |
| Message assistant_message_2 = | |
| JsonMessage(nlohmann::ordered_json::parse(R"json({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "Indeed." | |
| } | |
| ] | |
| })json")); | |
| auto message_callback_2 = [&done_2](absl::StatusOr<Message> message) { | |
| if (!message.ok()) { | |
| done_2.Notify(); | |
| return; | |
| } | |
| if (auto* json_msg = std::get_if<JsonMessage>(&message.value())) { | |
| if (json_msg->is_null()) { | |
| done_2.Notify(); | |
| } | |
| } | |
| }; | |
| // Expect rewind to checkpoint when second user message is sent. | |
| EXPECT_CALL(*mock_session_ptr, | |
| RewindToCheckpoint("channel_content_checkpoint")) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect the previous assistant message and the new user message to be | |
| // prefilled asynchronously. The previous assistant message should not contain | |
| // channel content. | |
| auto assistant_message_matcher = | |
| AllOf(HasSubstr("I am good."), Not(HasSubstr("hmm"))); | |
| auto message_input_matcher = ElementsAre( | |
| VariantWith<InputText>(ResultOf(get_text, assistant_message_matcher)), | |
| VariantWith<InputText>(ResultOf(get_text, HasSubstr("That's great.")))); | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefillAsync(message_input_matcher, testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| // Expect a new checkpoint to be saved before decode of second turn. | |
| EXPECT_CALL(*mock_session_ptr, SaveCheckpoint("channel_content_checkpoint")) | |
| .WillOnce(Return(absl::OkStatus())); | |
| // Expect decode after second user message. | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, {"Indeed."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| // Send the second user message. | |
| EXPECT_OK(conversation->SendMessageAsync(user_message_2, | |
| std::move(message_callback_2))); | |
| ASSERT_TRUE(done_2.WaitForNotificationWithTimeout(absl::Seconds(10))); | |
| } | |
| TEST_P(ConversationTest, SendMultipleMessagesAsyncWithHistory) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // The first user message. | |
| JsonMessage user_message_1 = nlohmann::ordered_json::parse(R"json( | |
| { | |
| "role": "user", | |
| "content": "How are you?" | |
| } | |
| )json"); | |
| absl::string_view expected_input_text1 = | |
| "<start_of_turn>user\n" | |
| "How are you?<end_of_turn>\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_input_text1))), | |
| testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, {"I am good."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| Message assistant_message_1 = | |
| JsonMessage(nlohmann::ordered_json::parse(R"json({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "I am good." | |
| } | |
| ] | |
| })json")); | |
| Message assistant_message_1_for_confirm = assistant_message_1; | |
| absl::Notification done_1; | |
| EXPECT_OK(conversation->SendMessageAsync( | |
| user_message_1, CreateTestMessageCallback(assistant_message_1, done_1))); | |
| done_1.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| ASSERT_THAT(conversation->GetHistory().size(), testing::Eq(2)); | |
| // We will send two consecutive messages when the history is not empty. | |
| JsonMessage user_messages = nlohmann::ordered_json::parse(R"json( | |
| [ | |
| { | |
| "role": "user", | |
| "content": "foo" | |
| }, | |
| { | |
| "role": "user", | |
| "content": "bar" | |
| } | |
| ] | |
| )json"); | |
| absl::string_view expected_input_text2 = | |
| "<start_of_turn>user\n" | |
| "foo<end_of_turn>\n" | |
| "<start_of_turn>user\n" | |
| "bar<end_of_turn>\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_input_text2))), | |
| testing::_)) | |
| .WillOnce([](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, {"baz"})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| Message assistant_message_2 = | |
| JsonMessage(nlohmann::ordered_json::parse(R"json({ | |
| "role": "assistant", | |
| "content": [ | |
| { | |
| "type": "text", | |
| "text": "baz" | |
| } | |
| ] | |
| })json")); | |
| Message assistant_message_2_for_confirm = assistant_message_2; | |
| absl::Notification done_2; | |
| auto message_callbacks_2 = | |
| CreateTestMessageCallback(assistant_message_2, done_2); | |
| EXPECT_OK(conversation->SendMessageAsync(user_messages, | |
| std::move(message_callbacks_2))); | |
| done_2.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_THAT( | |
| conversation->GetHistory(), | |
| testing::ElementsAre(user_message_1, assistant_message_1_for_confirm, | |
| user_messages[0], user_messages[1], | |
| assistant_message_2_for_confirm)); | |
| } | |
| TEST_P(ConversationTest, SendMessageWithPreface) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(15); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetPreface(JsonPreface{ | |
| .messages = {{{"role", "system"}, | |
| {"content", "You are a helpful assistant."}}}}) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| ASSERT_OK_AND_ASSIGN(const Message message, | |
| conversation->SendMessage(JsonMessage{ | |
| {"role", "user"}, {"content", "Hello world!"}})); | |
| // The expected message is just some gibberish text, because the test LLM has | |
| // random weights. | |
| JsonMessage expected_message; | |
| if (prefill_preface_on_init_) { | |
| expected_message = {{"role", "assistant"}, | |
| {"content", | |
| {{{"type", "text"}, | |
| {"text", " rupani rupani rupani echoes echoes"}}}}}; | |
| } else { | |
| expected_message = { | |
| {"role", "assistant"}, | |
| {"content", | |
| {{{"type", "text"}, | |
| {"text", " noses</caption> গ্রাহ<unused5296> omp"}}}}}; | |
| } | |
| const JsonMessage& json_message = std::get<JsonMessage>(message); | |
| EXPECT_EQ(json_message, expected_message); | |
| } | |
| TEST_P(ConversationTest, GetBenchmarkInfo) { | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(15); | |
| proto::BenchmarkParams benchmark_params; | |
| engine_settings.GetMutableBenchmarkParams() = benchmark_params; | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto config, | |
| ConversationConfig::Builder() | |
| .SetPreface(JsonPreface{ | |
| .messages = {{{"role", "system"}, | |
| {"content", "You are a helpful assistant."}}}}) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| ASSERT_OK_AND_ASSIGN(const Message message_1, | |
| conversation->SendMessage(JsonMessage{ | |
| {"role", "user"}, {"content", "Hello world!"}})); | |
| ASSERT_OK_AND_ASSIGN(const BenchmarkInfo benchmark_info_1, | |
| conversation->GetBenchmarkInfo()); | |
| EXPECT_EQ(benchmark_info_1.GetTotalPrefillTurns(), | |
| prefill_preface_on_init_ ? 2 : 1); | |
| ASSERT_OK_AND_ASSIGN(const Message message_2, | |
| conversation->SendMessage(JsonMessage{ | |
| {"role", "user"}, {"content", "Hello world!"}})); | |
| ASSERT_OK_AND_ASSIGN(const BenchmarkInfo benchmark_info_2, | |
| conversation->GetBenchmarkInfo()); | |
| EXPECT_EQ(benchmark_info_2.GetTotalPrefillTurns(), | |
| prefill_preface_on_init_ ? 3 : 2); | |
| } | |
| TEST_P(ConversationTest, CancelGroupWithSendMessageAsync) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // We will send a single message. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| auto mock_task_controller1 = std::make_unique<MockTaskController>(); | |
| // Expect Cancel() to be called on the first task controller when | |
| // CancelGroup("group1") is called. | |
| EXPECT_CALL(*mock_task_controller1, Cancel()) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| auto mock_task_controller2 = std::make_unique<MockTaskController>(); | |
| // Expect Cancel() to be called on the second task controller when | |
| // CancelGroup("group1") is called. | |
| EXPECT_CALL(*mock_task_controller2, Cancel()) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // Expect RunPrefillAsync to be called and return the first task controller. | |
| EXPECT_CALL(*mock_session_ptr, RunPrefillAsync(testing::_, testing::_)) | |
| .WillOnce([&](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> | |
| user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return std::move(mock_task_controller1); | |
| }); | |
| // Expect RunDecodeAsync to be called and return the second task controller. | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [&](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| return std::move(mock_task_controller2); | |
| }); | |
| absl::Notification done; | |
| absl::Status status; | |
| EXPECT_OK( | |
| conversation->SendMessageAsync(user_message, | |
| [&](absl::StatusOr<Message> message) { | |
| status = message.status(); | |
| done.Notify(); | |
| }, | |
| {.task_group_id = "group1"})); | |
| conversation->CancelGroup("group1"); | |
| } | |
| TEST_P(ConversationTest, CancelGroupWithRunTextScoringAsync) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto cloned_session = std::make_unique<MockSession>(); | |
| // Expect GetSessionConfig to be called on the cloned session. | |
| MockSession* cloned_session_ptr = cloned_session.get(); | |
| EXPECT_CALL(*cloned_session_ptr, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config_)); | |
| // Expect CloneAsync to be called and return the cloned session. | |
| EXPECT_CALL(*mock_session_ptr, CloneAsync(testing::_)) | |
| .WillOnce(testing::Return(std::move(cloned_session))); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer_)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| auto mock_task_controller = std::make_unique<MockTaskController>(); | |
| // Expect Cancel() to be called on the task controller when | |
| // CancelGroup("group1") is called. | |
| EXPECT_CALL(*mock_task_controller, Cancel()) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // Expect RunTextScoringAsync to be called on the cloned session and return | |
| // the task controller. | |
| EXPECT_CALL( | |
| *cloned_session_ptr, | |
| RunTextScoringAsync(testing::ElementsAre("I am good."), testing::_, true)) | |
| .WillOnce( | |
| [&](const std::vector<absl::string_view>& target_text, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| bool store_token_lengths) { | |
| return std::move(mock_task_controller); | |
| }); | |
| absl::Notification done; | |
| std::string response_text; | |
| EXPECT_OK(conversation->RunTextScoringAsync( | |
| {"I am good."}, | |
| [&](absl::StatusOr<Responses> responses) { | |
| ASSERT_OK(responses); | |
| response_text = responses->GetTexts()[0]; | |
| done.Notify(); | |
| }, | |
| {.task_group_id = "group1"})); | |
| conversation->CancelGroup("group1"); | |
| } | |
| INSTANTIATE_TEST_SUITE_P( | |
| ConversationTest, ConversationTest, | |
| testing::ValuesIn(ConversationTest::GetTestParams()), | |
| [](const testing::TestParamInfo<ConversationTestParams>& info) { | |
| return absl::StrCat( | |
| info.param.enable_constrained_decoding ? "Constrained" : "Free", "_", | |
| info.param.prefill_preface_on_init ? "PrefillOnInit" | |
| : "NoPrefillOnInit"); | |
| }); | |
| absl::AnyInvocable<void(absl::StatusOr<Message>)> | |
| CreateCancelledMessageCallback(absl::Status& status, absl::Notification& done) { | |
| return [&status, &done](absl::StatusOr<Message> message) mutable { | |
| if (!message.ok()) { | |
| status = message.status(); | |
| done.Notify(); | |
| return; | |
| } | |
| if (auto json_message = std::get_if<JsonMessage>(&message.value()); | |
| json_message->is_null()) { | |
| status = absl::OkStatus(); | |
| done.Notify(); | |
| return; | |
| } | |
| // Wait for a short time to slow down the decoding process, so that the | |
| // cancellation can be triggered in the middle of decoding. | |
| absl::SleepFor(absl::Milliseconds(100)); | |
| }; | |
| } | |
| TEST(ConversationAccessHistoryTest, AccessHistory) { | |
| // Create a Conversation. | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(10); | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN(auto config, ConversationConfig::CreateDefault(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| // Send a message to the LLM. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "Hello world!"}}; | |
| Message expected_assistant_message = | |
| JsonMessage({{"role", "assistant"}, | |
| {"content", | |
| {{{"type", "text"}, | |
| {"text", "TarefaByte دارایेत्र investigaciónప్రదేశ"}}}}}); | |
| Message expected_assistant_message_for_confirm = expected_assistant_message; | |
| absl::Notification done; | |
| EXPECT_OK(conversation->SendMessageAsync( | |
| user_message, | |
| CreateTestMessageCallback(expected_assistant_message, done))); | |
| done.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| // Get the history copy. | |
| auto history = conversation->GetHistory(); | |
| ASSERT_THAT(history.size(), 2); | |
| ASSERT_THAT(history.back(), | |
| testing::VariantWith<JsonMessage>(std::get<JsonMessage>( | |
| expected_assistant_message_for_confirm))); | |
| // Access the history with visitor function, and copy the last message. | |
| Message last_message; | |
| conversation->AccessHistory( | |
| [&last_message](const std::vector<Message>& history_view) { | |
| // Copy the last message to last_message. So we don't need to | |
| // copy the whole history, if we only need the last message. | |
| last_message = history_view.back(); | |
| }); | |
| EXPECT_THAT(last_message, | |
| testing::VariantWith<JsonMessage>(std::get<JsonMessage>( | |
| expected_assistant_message_for_confirm))); | |
| } | |
| class ConversationCancellationTest : public testing::TestWithParam<bool> { | |
| protected: | |
| bool use_benchmark_info_ = GetParam(); | |
| }; | |
| TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) { | |
| bool use_benchmark_info = use_benchmark_info_; | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache"); | |
| // Set a large max num tokens to ensure the decoding is not finished before | |
| // cancellation. | |
| engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(20); | |
| if (use_benchmark_info) { | |
| proto::BenchmarkParams benchmark_params; | |
| engine_settings.GetMutableBenchmarkParams() = benchmark_params; | |
| } | |
| ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings)); | |
| ASSERT_OK_AND_ASSIGN(auto config, ConversationConfig::CreateDefault(*engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*engine, config)); | |
| absl::Status status; | |
| absl::Notification done_1; | |
| conversation | |
| ->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", "Hello world!"}}, | |
| CreateCancelledMessageCallback(status, done_1)) | |
| .IgnoreError(); | |
| // Wait for a short time to ensure the decoding has started. | |
| absl::SleepFor(absl::Milliseconds(100)); | |
| conversation->CancelProcess(); | |
| // Wait for the callback to be done. | |
| done_1.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| EXPECT_THAT(status, testing::status::StatusIs(absl::StatusCode::kCancelled)); | |
| // The history should be empty after cancellation. | |
| EXPECT_THAT(conversation->GetHistory().size(), 0); | |
| // Resend the message after cancellation, and it should succeed. | |
| status = absl::OkStatus(); | |
| absl::Notification done_2; | |
| conversation | |
| ->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", "Hello world!"}}, | |
| CreateCancelledMessageCallback(status, done_2)) | |
| .IgnoreError(); | |
| EXPECT_OK(status); | |
| // Wait for the callback to be done. | |
| done_2.WaitForNotificationWithTimeout(absl::Seconds(10)); | |
| // Without cancellation, the history should have two messages, user and | |
| // assistant. | |
| auto history = conversation->GetHistory(); | |
| ASSERT_EQ(history.size(), 2); | |
| EXPECT_THAT(history[0], testing::VariantWith<JsonMessage>(JsonMessage{ | |
| {"role", "user"}, {"content", "Hello world!"}})); | |
| // TODO(b/450903294) - Because the cancellation is not fully rollbacked, the | |
| // assistant message content depends on at which step the cancellation is | |
| // triggered, and that is non-deterministic. Here we only check the role is | |
| // assistant. | |
| EXPECT_THAT(std::holds_alternative<JsonMessage>(history[1]), | |
| testing::IsTrue()); | |
| EXPECT_EQ(std::get<JsonMessage>(history[1])["role"], "assistant"); | |
| conversation->CancelProcess(); | |
| // No op after cancellation again. | |
| EXPECT_THAT(conversation->GetHistory().size(), 2); | |
| } | |
| INSTANTIATE_TEST_SUITE_P(ConversationCancellationTest, | |
| ConversationCancellationTest, testing::Bool(), | |
| testing::PrintToStringParamName()); | |
| class MockConstraint : public Constraint { | |
| public: | |
| class MockState : public State { | |
| public: | |
| ~MockState() override = default; | |
| }; | |
| MOCK_METHOD(std::unique_ptr<State>, Start, (), (const, override)); | |
| MOCK_METHOD(bool, IsEnded, (const State& state), (const, override)); | |
| MOCK_METHOD(int, GetVocabularySize, (), (const, override)); | |
| MOCK_METHOD(absl::StatusOr<std::unique_ptr<State>>, ComputeNext, | |
| (const State& state, int token), (const, override)); | |
| MOCK_METHOD(absl::StatusOr<std::unique_ptr<Bitmap>>, ComputeBitmap, | |
| (const State& state), (const, override)); | |
| }; | |
| TEST_P(ConversationTest, SendMessageWithConstraint) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| // Create Conversation with ExternalConstraintConfig. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetConstraintProviderConfig(ExternalConstraintConfig()) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Create a mock constraint. | |
| auto mock_constraint = std::make_unique<MockConstraint>(); | |
| Constraint* mock_constraint_ptr = mock_constraint.get(); | |
| ExternalConstraintArg constraint_arg; | |
| constraint_arg.constraint = std::move(mock_constraint); | |
| // Send a message with the constraint. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| EXPECT_CALL(*mock_session_ptr, RunPrefill(testing::_)) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // Verify that the constraint is passed to RunDecode. | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunDecode(testing::Property(&DecodeConfig::GetConstraint, | |
| mock_constraint_ptr))) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN( | |
| const Message response, | |
| conversation->SendMessage( | |
| user_message, { | |
| .decoding_constraint = std::move(constraint_arg), | |
| })); | |
| } | |
| TEST_P(ConversationTest, Clone) { | |
| // Set up mock Session. | |
| auto mock_session = CreateMockSession(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| auto mock_engine = CreateMockEngine(std::move(mock_session)); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer_)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config_) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .SetEnableConstrainedDecoding(enable_constrained_decoding_) | |
| .SetPrefillPrefaceOnInit(prefill_preface_on_init_) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Send a message to populate history. | |
| JsonMessage user_message = {{"role", "user"}, {"content", "Hello"}}; | |
| EXPECT_CALL(*mock_session_ptr, RunPrefill(testing::_)) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce(testing::Return(Responses(TaskState::kProcessing, {"Hi"}))); | |
| ASSERT_OK(conversation->SendMessage(user_message)); | |
| // Expect Session::Clone to be called. | |
| auto cloned_mock_session = std::make_unique<MockSession>(); | |
| MockSession* cloned_mock_session_ptr = cloned_mock_session.get(); | |
| EXPECT_CALL(*cloned_mock_session_ptr, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config_)); | |
| EXPECT_CALL(*mock_session_ptr, Clone()) | |
| .WillOnce(testing::Return(std::move(cloned_mock_session))); | |
| // Clone the conversation. | |
| ASSERT_OK_AND_ASSIGN(auto cloned_conversation, conversation->Clone()); | |
| // Verify the history in the cloned conversation. | |
| auto history = cloned_conversation->GetHistory(); | |
| EXPECT_EQ(history.size(), 2); | |
| EXPECT_EQ(std::get<JsonMessage>(history[0]), user_message); | |
| // Verify that sending a message in the cloned conversation works and uses the | |
| // cloned session. | |
| JsonMessage user_message2 = {{"role", "user"}, {"content", "How are you?"}}; | |
| EXPECT_CALL(*cloned_mock_session_ptr, RunPrefill(testing::_)) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*cloned_mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK(cloned_conversation->SendMessage(user_message2)); | |
| // Verify that the original conversation is unaffected by the new message in | |
| // the cloned one. | |
| EXPECT_EQ(conversation->GetHistory().size(), 2); | |
| EXPECT_EQ(cloned_conversation->GetHistory().size(), 4); | |
| } | |
| TEST_P(ConversationTest, SendMessageWithMaxOutputTokens) { | |
| // Set up mock Session. | |
| auto mock_session = std::make_unique<MockSession>(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(0); | |
| session_config.GetMutableStopTokenIds().push_back({1}); | |
| *session_config.GetMutableLlmModelType().mutable_gemma3() = {}; | |
| EXPECT_CALL(*mock_session_ptr, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config)); | |
| // Set up mock Engine. | |
| auto mock_engine = std::make_unique<MockEngine>(); | |
| EXPECT_CALL(*mock_engine, CreateSession(testing::_)) | |
| .WillOnce(testing::Return(std::move(mock_session))); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer_)); | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| EXPECT_CALL(*mock_engine, GetEngineSettings()) | |
| .WillRepeatedly(testing::ReturnRef(engine_settings)); | |
| // Create Conversation with default config. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config) | |
| .SetOverwritePromptTemplate(PromptTemplate(kTestJinjaPromptTemplate)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| JsonMessage user_message = {{"role", "user"}, {"content", "How are you?"}}; | |
| EXPECT_CALL(*mock_session_ptr, RunPrefill(testing::_)) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // Verify that the max_output_tokens is passed to RunDecode. | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunDecode(testing::Property(&DecodeConfig::GetMaxOutputTokens, | |
| std::make_optional(42)))) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN( | |
| const Message response, | |
| conversation->SendMessage(user_message, {.max_output_tokens = 42})); | |
| } | |
| TEST(AppendMessageTest, Gemma3Sync) { | |
| // Set up mock Session. | |
| auto mock_session = std::make_unique<MockSession>(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(0); | |
| session_config.GetMutableStopTokenIds().push_back({1}); | |
| *session_config.GetMutableLlmModelType().mutable_gemma3() = {}; | |
| session_config.SetApplyPromptTemplateInSession(false); | |
| EXPECT_CALL(*mock_session_ptr, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto tokenizer, | |
| SentencePieceTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / kTestTokenizerPath) | |
| .string())); | |
| // Set up mock Engine. | |
| auto mock_engine = std::make_unique<MockEngine>(); | |
| EXPECT_CALL(*mock_engine, CreateSession(testing::_)) | |
| .WillOnce(testing::Return(std::move(mock_session))); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer)); | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| EXPECT_CALL(*mock_engine, GetEngineSettings()) | |
| .WillRepeatedly(testing::ReturnRef(engine_settings)); | |
| std::string template_text = | |
| ReadFile(GetTestdataPath(kGemma3ToolsMultiPrefillTemplatePath)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config) | |
| .SetOverwritePromptTemplate(PromptTemplate(template_text)) | |
| .SetPreface(JsonPreface{ | |
| .messages = {{{"role", "system"}, | |
| {"content", "You are a helpful assistant."}}}}) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Append the 1st message. | |
| absl::string_view expected_prefill_1 = | |
| "<start_of_turn>user\nYou are a helpful " | |
| "assistant.\n\n<end_of_turn>\n<start_of_turn>user\nHello world!"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_1))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "user"}, {"content", "Hello world!"}}, | |
| {.has_pending_message = true})); | |
| // Append the 2nd message. | |
| absl::string_view expected_prefill_2 = " This is a long message."; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_2))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "user"}, {"content", " This is a long message."}}, | |
| {.has_pending_message = true})); | |
| // Append the 3rd message. | |
| absl::string_view expected_prefill_3 = " continuing..."; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_3))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "user"}, {"content", " continuing..."}}, | |
| {.has_pending_message = true})); | |
| // Finish appending message. | |
| absl::string_view expected_prefill_4 = | |
| " The message is ended.<end_of_turn>\n<start_of_turn>model\n"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_4))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK_AND_ASSIGN( | |
| const Message response_appending, | |
| conversation->SendMessage(JsonMessage{ | |
| {"role", "user"}, {"content", " The message is ended."}})); | |
| } | |
| TEST(AppendMessageTest, Gemma3Async) { | |
| // Set up mock Session. | |
| auto mock_session = std::make_unique<MockSession>(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(0); | |
| session_config.GetMutableStopTokenIds().push_back({1}); | |
| *session_config.GetMutableLlmModelType().mutable_gemma3() = {}; | |
| session_config.SetApplyPromptTemplateInSession(false); | |
| EXPECT_CALL(*mock_session_ptr, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto tokenizer, | |
| SentencePieceTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / kTestTokenizerPath) | |
| .string())); | |
| // Set up mock Engine. | |
| auto mock_engine = std::make_unique<MockEngine>(); | |
| EXPECT_CALL(*mock_engine, CreateSession(testing::_)) | |
| .WillOnce(testing::Return(std::move(mock_session))); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer)); | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| EXPECT_CALL(*mock_engine, GetEngineSettings()) | |
| .WillRepeatedly(testing::ReturnRef(engine_settings)); | |
| std::string template_text = | |
| ReadFile(GetTestdataPath(kGemma3ToolsMultiPrefillTemplatePath)); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config) | |
| .SetOverwritePromptTemplate(PromptTemplate(template_text)) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| auto test_callback = | |
| [](const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback) { | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }; | |
| // Append the 1st message. | |
| absl::string_view expected_prefill_1 = "<start_of_turn>user\nHello world!"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_prefill_1))), | |
| testing::_)) | |
| .Times(1) | |
| .WillOnce(test_callback); | |
| absl::Notification done1; | |
| ASSERT_OK(conversation->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", "Hello world!"}}, | |
| [&done1](absl::StatusOr<Message> message) { done1.Notify(); }, | |
| {.has_pending_message = true})); | |
| done1.WaitForNotificationWithTimeout(absl::Seconds(3)); | |
| // Append the 2nd message. | |
| absl::string_view expected_prefill_2 = " This is a long message."; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_prefill_2))), | |
| testing::_)) | |
| .Times(1) | |
| .WillOnce(test_callback); | |
| absl::Notification done2; | |
| ASSERT_OK(conversation->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", " This is a long message."}}, | |
| [&done2](absl::StatusOr<Message> message) { done2.Notify(); }, | |
| {.has_pending_message = true})); | |
| done2.WaitForNotificationWithTimeout(absl::Seconds(3)); | |
| // Append the 3rd message. | |
| absl::string_view expected_prefill_3 = " continuing..."; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_prefill_3))), | |
| testing::_)) | |
| .Times(1) | |
| .WillOnce(test_callback); | |
| absl::Notification done3; | |
| ASSERT_OK(conversation->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", " continuing..."}}, | |
| [&done3](absl::StatusOr<Message> message) { done3.Notify(); }, | |
| {.has_pending_message = true})); | |
| done3.WaitForNotificationWithTimeout(absl::Seconds(3)); | |
| // Append the 4th message. | |
| absl::string_view expected_prefill_4 = " The message is ended."; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_prefill_4))), | |
| testing::_)) | |
| .Times(1) | |
| .WillOnce(test_callback); | |
| absl::Notification done4; | |
| EXPECT_OK(conversation->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", " The message is ended."}}, | |
| [&done4](absl::StatusOr<Message> message) { done4.Notify(); }, | |
| {.has_pending_message = true})); | |
| done4.WaitForNotificationWithTimeout(absl::Seconds(3)); | |
| // The 5th message triggers the decode. | |
| absl::string_view expected_prefill_5 = | |
| "<end_of_turn>\n<start_of_turn>model\n"; | |
| EXPECT_CALL( | |
| *mock_session_ptr, | |
| RunPrefillAsync(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_prefill_5))), | |
| testing::_)) | |
| .Times(1) | |
| .WillOnce(test_callback); | |
| EXPECT_CALL(*mock_session_ptr, RunDecodeAsync(testing::_, testing::_)) | |
| .WillOnce( | |
| [](absl::AnyInvocable<void(absl::StatusOr<Responses>)> user_callback, | |
| const DecodeConfig& decode_config) { | |
| user_callback(Responses(TaskState::kProcessing, {"I am good."})); | |
| user_callback(Responses(TaskState::kDone)); | |
| return nullptr; | |
| }); | |
| Message expected_assistant_message = | |
| JsonMessage({{"role", "assistant"}, | |
| {"content", {{{"type", "text"}, {"text", "I am good."}}}}}); | |
| absl::Notification done5; | |
| // Trigger the decode by sending an empty message. | |
| EXPECT_OK(conversation->SendMessageAsync( | |
| JsonMessage{{"role", "user"}, {"content", ""}}, | |
| CreateTestMessageCallback(expected_assistant_message, done5), | |
| {.has_pending_message = false})); | |
| done5.WaitForNotificationWithTimeout(absl::Seconds(3)); | |
| } | |
| TEST(AppendMessageTest, Gemma3SyncPrefillPrefaceOnInitAndAlternateRoles) { | |
| // Set up mock Session. | |
| auto mock_session = std::make_unique<MockSession>(); | |
| MockSession* mock_session_ptr = mock_session.get(); | |
| SessionConfig session_config = SessionConfig::CreateDefault(); | |
| session_config.SetStartTokenId(0); | |
| session_config.GetMutableStopTokenIds().push_back({1}); | |
| *session_config.GetMutableLlmModelType().mutable_gemma3() = {}; | |
| session_config.SetApplyPromptTemplateInSession(false); | |
| EXPECT_CALL(*mock_session_ptr, GetSessionConfig()) | |
| .WillRepeatedly(testing::ReturnRef(session_config)); | |
| ASSERT_OK_AND_ASSIGN( | |
| auto tokenizer, | |
| SentencePieceTokenizer::CreateFromFile( | |
| (std::filesystem::path(::testing::SrcDir()) / kTestTokenizerPath) | |
| .string())); | |
| // Set up mock Engine. | |
| auto mock_engine = std::make_unique<MockEngine>(); | |
| EXPECT_CALL(*mock_engine, CreateSession(testing::_)) | |
| .WillOnce(testing::Return(std::move(mock_session))); | |
| EXPECT_CALL(*mock_engine, GetTokenizer()) | |
| .WillRepeatedly(testing::ReturnRef(*tokenizer)); | |
| ASSERT_OK_AND_ASSIGN(auto model_assets, | |
| ModelAssets::Create(GetTestdataPath(kTestLlmPath))); | |
| ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault( | |
| model_assets, Backend::CPU)); | |
| EXPECT_CALL(*mock_engine, GetEngineSettings()) | |
| .WillRepeatedly(testing::ReturnRef(engine_settings)); | |
| std::string template_text = | |
| ReadFile(GetTestdataPath(kGemma3ToolsMultiPrefillTemplatePath)); | |
| // Init with preface. | |
| absl::string_view expected_prefill_preface = R"(<start_of_turn>system | |
| def tool_name( | |
| x: int | None = None, | |
| ) -> dict: | |
| """ | |
| Args: | |
| x """ | |
| <end_of_turn> | |
| <start_of_turn>user | |
| You are a helpful assistant. | |
| <end_of_turn> | |
| )"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre(testing::VariantWith<InputText>( | |
| testing::Property(&InputText::GetRawTextString, | |
| expected_prefill_preface))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| // Create Conversation. | |
| ASSERT_OK_AND_ASSIGN( | |
| auto conversation_config, | |
| ConversationConfig::Builder() | |
| .SetSessionConfig(session_config) | |
| .SetOverwritePromptTemplate(PromptTemplate(template_text)) | |
| .SetPreface(JsonPreface{ | |
| .messages = {{{"role", "system"}, | |
| {"content", "You are a helpful assistant."}}}, | |
| .tools = nlohmann::ordered_json::parse( | |
| R"json([{ | |
| "name": "tool_name", | |
| "parameters": { | |
| "properties": { | |
| "x": { | |
| "type": "integer" | |
| } | |
| } | |
| } | |
| }])json")}) | |
| .SetPrefillPrefaceOnInit(true) | |
| .Build(*mock_engine)); | |
| ASSERT_OK_AND_ASSIGN(auto conversation, | |
| Conversation::Create(*mock_engine, conversation_config)); | |
| // Append the 1st message. | |
| absl::string_view expected_prefill_1 = "<start_of_turn>user\nHello world!"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_1))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "user"}, {"content", "Hello world!"}}, | |
| {.has_pending_message = true})); | |
| // Append the 2nd message. | |
| absl::string_view expected_prefill_2 = | |
| "<end_of_turn>\n<start_of_turn>model\nNice to meet you."; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_2))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "model"}, {"content", "Nice to meet you."}}, | |
| {.has_pending_message = true})); | |
| // Append the 3rd message. | |
| absl::string_view expected_prefill_3 = " How can I help you today?"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_3))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "model"}, {"content", " How can I help you today?"}}, | |
| {.has_pending_message = true})); | |
| // Append the 4th message. | |
| absl::string_view expected_prefill_4 = " The message is ended."; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_4))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK(conversation->SendMessage( | |
| JsonMessage{{"role", "model"}, {"content", " The message is ended."}}, | |
| {.has_pending_message = true})); | |
| // Append the 5th message. | |
| absl::string_view expected_prefill_5 = R"(<end_of_turn> | |
| <start_of_turn>user | |
| ```tool_outputs | |
| {"location": "Paris", "temperature": 20, "unit": "C", "weather": "Sunny"})"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_5))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| ASSERT_OK( | |
| conversation->SendMessage(JsonMessage{{"role", "tool"}, | |
| {"content", | |
| { | |
| {"type", "tool_response"}, | |
| {"tool_response", | |
| { | |
| {"location", "Paris"}, | |
| {"temperature", 20}, | |
| {"unit", "C"}, | |
| {"weather", "Sunny"}, | |
| }}, | |
| }}}, | |
| {.has_pending_message = true})); | |
| // Append the 6th message. | |
| absl::string_view expected_prefill_6 = | |
| R"({"location": "London", "temperature": 15, "unit": "C", "weather": "Cloudy"} | |
| ```<end_of_turn> | |
| <start_of_turn>model | |
| )"; | |
| EXPECT_CALL(*mock_session_ptr, | |
| RunPrefill(testing::ElementsAre( | |
| testing::VariantWith<InputText>(testing::Property( | |
| &InputText::GetRawTextString, expected_prefill_6))))) | |
| .Times(1) | |
| .WillOnce(testing::Return(absl::OkStatus())); | |
| EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_)) | |
| .WillOnce( | |
| testing::Return(Responses(TaskState::kProcessing, {"I am good."}))); | |
| ASSERT_OK( | |
| conversation->SendMessage(JsonMessage{{"role", "tool"}, | |
| {"content", | |
| { | |
| {"type", "tool_response"}, | |
| {"tool_response", | |
| { | |
| {"location", "London"}, | |
| {"temperature", 15}, | |
| {"unit", "C"}, | |
| {"weather", "Cloudy"}, | |
| }}, | |
| }}}, | |
| {.has_pending_message = false})); | |
| } | |
| } // namespace | |
| } // namespace litert::lm | |