// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "schema/core/litertlm_section.h" #include #include #include // NOLINT: Required for path manipulation. #include #include #include #include #include #include #include #include "absl/status/status.h" // from @com_google_absl #include "runtime/proto/llm_metadata.pb.h" #include "runtime/proto/token.pb.h" namespace litert::lm::schema { namespace { TEST(LiteRTLMSectionTest, TestFileBackedSectionStream) { // Get the file path const auto file_path = std::filesystem::path(::testing::SrcDir()) / "litert_lm/schema/testdata/attention.tflite"; const auto output_file_path = std::filesystem::path(::testing::TempDir()) / "attention_copy.tflite"; // Define an output file path std::ofstream output_file(output_file_path, std::ios::binary); EXPECT_TRUE(output_file.is_open()); // Create the file-backed Section stream object FileBackedSectionStream fbss(file_path.string()); absl::Status result = fbss.Prepare(); ASSERT_TRUE(result.ok()); size_t fbss_size = fbss.BufferSize(); auto& fbss_stream = fbss.GetStream(); output_file << fbss_stream.rdbuf(); EXPECT_EQ(output_file.tellp(), fbss_size); output_file.close(); // Read the file back and check contents. std::ifstream input_file(output_file_path, std::ios::binary); ASSERT_TRUE(input_file.is_open()); // Read the original file into a buffer std::ifstream original_file(file_path, std::ios::binary | std::ios::ate); ASSERT_TRUE(original_file.is_open()); size_t original_size = original_file.tellg(); original_file.seekg(0, std::ios::beg); std::vector original_buffer(original_size); original_file.read(original_buffer.data(), original_size); original_file.close(); // Read the copied file into a buffer std::vector copied_buffer(fbss_size); input_file.read(copied_buffer.data(), fbss_size); input_file.close(); // Compare the buffers ASSERT_EQ(original_size, fbss_size); EXPECT_TRUE(std::equal(original_buffer.begin(), original_buffer.end(), copied_buffer.begin(), copied_buffer.end())); } TEST(LiteRTLMSectionTest, TestProtoSectionStream) { using litert::lm::proto::LlmMetadata; // Constants for the Token Generation Data const std::string start_token = ""; const std::vector stop_tokens = {"", ""}; const auto output_file_path = std::filesystem::path(::testing::TempDir()) / "llm_metadata.pb"; // Create an LlmMetadata protocol buffer LlmMetadata metadata; // Set the start_token metadata.mutable_start_token()->set_token_str(start_token); // Set the stop_tokens for (const std::string& stop_token : stop_tokens) { metadata.add_stop_tokens()->set_token_str(stop_token); } // ** Write the file using typical protobuf serialization ** std::string serialized_params = metadata.SerializeAsString(); // Convert the serialized string to a vector of unsigned chars std::vector buffer(serialized_params.begin(), serialized_params.end()); // Write the buffer to a file std::ofstream output_file(output_file_path, std::ios::binary); ASSERT_TRUE(output_file.is_open()); output_file.write(reinterpret_cast(buffer.data()), buffer.size()); std::ofstream::pos_type bytes_written = output_file.tellp(); ASSERT_GT(bytes_written, 0); output_file.close(); // ** Write the file using SectionStream interface ** ProtoBufSectionStream pbss(metadata); absl::Status result = pbss.Prepare(); ASSERT_TRUE(result.ok()); size_t pbss_size = pbss.BufferSize(); auto& pbss_stream = pbss.GetStream(); const auto output_file_streamed_path = std::filesystem::path(::testing::TempDir()) / "llm_metadata_streamed.pb"; std::ofstream output_streamed(output_file_streamed_path, std::ios::binary); ASSERT_TRUE(output_streamed.is_open()); output_streamed << pbss_stream.rdbuf(); EXPECT_EQ(output_streamed.tellp(), pbss_size); output_streamed.close(); // ** Read the file back in and check the contents ** std::ifstream input_streamed(output_file_streamed_path, std::ios::binary); ASSERT_TRUE(input_streamed.is_open()); std::stringstream ss; ss << input_streamed.rdbuf(); // Read the entire file into a stringstream std::string serialized_read_back = ss.str(); input_streamed.close(); LlmMetadata params_read_back; ASSERT_TRUE(params_read_back.ParseFromString(serialized_read_back)); // Compare the fields. EXPECT_EQ(metadata.start_token().token_str(), params_read_back.start_token().token_str()); EXPECT_EQ(metadata.stop_tokens().size(), params_read_back.stop_tokens().size()); for (int i = 0; i < metadata.stop_tokens().size(); ++i) { EXPECT_EQ(metadata.stop_tokens(i).token_str(), params_read_back.stop_tokens(i).token_str()); } } TEST(LiteRTLMSectionTest, TestStringBackedSectionStream) { using litert::lm::schema::StringBackedSectionStream; // Use a string with embedded nulls to ensure it handles binary data // correctly. const std::string original_data("Test\0Binary\0Data", 16); // Use the SectionStream interface. StringBackedSectionStream stream(original_data); // Prepare the stream for reading. absl::Status result = stream.Prepare(); EXPECT_TRUE(result.ok()); // Get the stream's properties. const size_t stream_size = stream.BufferSize(); std::istream& data_stream = stream.GetStream(); // The reported buffer size should match the original data's size. EXPECT_EQ(stream_size, original_data.size()); // Read the entire contents of the stream back into a new string. std::stringstream buffer; buffer << data_stream.rdbuf(); const std::string data_read_back = buffer.str(); // Compare the data that was read from the stream to the original data. // They must be identical in both size and content. EXPECT_EQ(data_read_back.size(), original_data.size()); EXPECT_EQ(data_read_back, original_data); // Finally, test that the stream can be finalized correctly. result = stream.Finalize(); EXPECT_TRUE(result.ok()); EXPECT_FALSE(stream.IsReady()); } } // namespace } // namespace litert::lm::schema