// Copyright 2026 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Unit tests for ExtractChannelContent and InsertChannelContentIntoMessage,
// declared in runtime/conversation/channel_util.h.
//
// NOTE(review): this copy of the file appears to have been damaged by a
// markup/tag-stripping step and will not compile as written:
//   * the six bare `#include` directives below have lost their `<...>`
//     header names;
//   * template argument lists are missing (`std::vector channels`,
//     `std::get_if(&message)`, `absl::flat_hash_map channel_content`);
//   * every channel's start/end delimiter literal is `""`, and the input
//     strings seem to have lost the delimiter text around "hmm", "yeah",
//     "lol".
// Restore these from version control rather than guessing at them.
#include "runtime/conversation/channel_util.h"

#include
#include
#include
#include
#include
#include

#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "runtime/conversation/io_types.h"
#include "runtime/engine/io_types.h"
#include "runtime/util/test_utils.h"  // IWYU pragma: keep

namespace litert::lm {
namespace {

using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

// A Responses built with only a TaskState (no texts) yields an OK result
// holding an empty channel map, and the response's text list stays empty.
//
// NOTE(review): each `{"thought", "", ""}` initializer is presumably
// {channel name, start delimiter, end delimiter}. Confirm this against the
// channel type's declaration (presumably in one of the io_types.h headers).
TEST(ExtractChannelTextTest, EmptyResponses) {
  Responses responses(TaskState::kDone);
  std::vector channels = {{"thought", "", ""}};
  auto channel_content = ExtractChannelContent(channels, responses);
  ASSERT_OK(channel_content);
  EXPECT_THAT(*channel_content, IsEmpty());
  EXPECT_THAT(responses.GetTexts(), IsEmpty());
}

// A Responses holding more than one text is rejected: the returned status
// is not OK and its code is kInvalidArgument.
TEST(ExtractChannelTextTest, MultipleTextsError) {
  Responses responses(TaskState::kProcessing, {"Text 1", "Text 2"});
  std::vector channels = {{"thought", "", ""}};
  auto channel_content = ExtractChannelContent(channels, responses);
  EXPECT_FALSE(channel_content.ok());
  EXPECT_THAT(channel_content.status().code(),
              Eq(absl::StatusCode::kInvalidArgument));
}

// A single channel span is returned keyed by the channel name ("thought" ->
// "hmm"). The span is also removed from the response text: ExtractChannelContent
// is given `responses`, and the test then checks its first text.
//
// NOTE(review): this and the following tests index GetTexts()[0] without
// first asserting the size. Consider adding
// ASSERT_THAT(responses.GetTexts(), ::testing::SizeIs(1)) so a regression
// fails cleanly instead of reading out of bounds.
TEST(ExtractChannelTextTest, SingleChannelOccurrence) {
  Responses responses(TaskState::kProcessing, {"Hello hmm World!"});
  std::vector channels = {{"thought", "", ""}};
  auto channel_content = ExtractChannelContent(channels, responses);
  ASSERT_OK(channel_content);
  EXPECT_THAT(*channel_content, UnorderedElementsAre(
                                    std::pair("thought", "hmm")));
  EXPECT_THAT(responses.GetTexts()[0], Eq("Hello World!"));
}

// When the same channel occurs twice, the expected map value is the
// occurrences concatenated in order of appearance with no separator
// ("hmm" + "yeah" -> "hmmyeah"). Both spans are removed from the text.
TEST(ExtractChannelTextTest, MultipleOccurrencesOfSameChannel) {
  Responses responses(TaskState::kProcessing, {"Hello hmm World yeah!"});
  std::vector channels = {{"thought", "", ""}};
  auto channel_content = ExtractChannelContent(channels, responses);
  ASSERT_OK(channel_content);
  EXPECT_THAT(*channel_content, UnorderedElementsAre(std::pair(
                                    "thought", "hmmyeah")));
  EXPECT_THAT(responses.GetTexts()[0], Eq("Hello World !"));
}

// With two configured channels, each channel's content is returned under
// its own name. Both spans are removed from the text.
TEST(ExtractChannelTextTest, MultipleDifferentChannels) {
  Responses responses(TaskState::kProcessing, {"Hello hmm World lol!"});
  std::vector channels = {
      {"thought", "", ""},
      {"joke", "", ""},
  };
  auto channel_content = ExtractChannelContent(channels, responses);
  ASSERT_OK(channel_content);
  EXPECT_THAT(*channel_content,
              UnorderedElementsAre(std::pair("thought", "hmm"),
                                   std::pair("joke", "lol")));
  EXPECT_THAT(responses.GetTexts()[0], Eq("Hello World !"));
}

// Text without any channel span: the test expects an OK, empty map and an
// unchanged response text.
TEST(ExtractChannelTextTest, NoChannelFound) {
  Responses responses(TaskState::kProcessing, {"Hello World!"});
  std::vector channels = {{"thought", "", ""}};
  auto channel_content = ExtractChannelContent(channels, responses);
  ASSERT_OK(channel_content);
  EXPECT_THAT(*channel_content, IsEmpty());
  EXPECT_THAT(responses.GetTexts()[0], Eq("Hello World!"));
}

// A channel that is opened but never closed: the test expects the remainder
// of the text ("hmm") to be taken as channel content and stripped, leaving
// "Hello " (note the trailing space).
TEST(ExtractChannelTextTest, MissingEndDelimiter) {
  Responses responses(TaskState::kProcessing, {"Hello hmm"});
  std::vector channels = {{"thought", "", ""}};
  auto channel_content = ExtractChannelContent(channels, responses);
  ASSERT_OK(channel_content);
  EXPECT_THAT(*channel_content, UnorderedElementsAre(
                                    std::pair("thought", "hmm")));
  EXPECT_THAT(responses.GetTexts()[0], Eq("Hello "));
}

// InsertChannelContentIntoMessage writes each channel's content into a
// JsonMessage-backed Message under ["channels"][<channel name>]. The Message
// must still hold a JsonMessage afterwards.
//
// NOTE(review): if JsonMessage is an nlohmann-style JSON type (its definition
// is not visible here), non-const operator[] inserts missing keys. In that
// case a contains("channels") check first would give a clearer failure.
TEST(InsertChannelContentIntoMessageTest, JsonMessageInsertion) {
  JsonMessage json_msg = {{"role", "assistant"}, {"content", "Hello!"}};
  Message message(json_msg);
  absl::flat_hash_map channel_content = {
      {"thought", "hmm"}};
  InsertChannelContentIntoMessage(channel_content, message);
  auto* result_json = std::get_if(&message);
  ASSERT_NE(result_json, nullptr);
  EXPECT_THAT((*result_json)["channels"]["thought"], Eq("hmm"));
}

}  // namespace
}  // namespace litert::lm