LiteRT-LM / schema /core /litertlm_section_test.cc
SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "schema/core/litertlm_section.h"
#include <algorithm>
#include <cstddef>
#include <filesystem> // NOLINT: Required for path manipulation.
#include <fstream>
#include <ios>
#include <istream>
#include <sstream>
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "absl/status/status.h" // from @com_google_absl
#include "runtime/proto/llm_metadata.pb.h"
#include "runtime/proto/token.pb.h"
namespace litert::lm::schema {
namespace {
TEST(LiteRTLMSectionTest, TestFileBackedSectionStream) {
// Get the file path
const auto file_path =
std::filesystem::path(::testing::SrcDir()) /
"litert_lm/schema/testdata/attention.tflite";
const auto output_file_path =
std::filesystem::path(::testing::TempDir()) / "attention_copy.tflite";
// Define an output file path
std::ofstream output_file(output_file_path, std::ios::binary);
EXPECT_TRUE(output_file.is_open());
// Create the file-backed Section stream object
FileBackedSectionStream fbss(file_path.string());
absl::Status result = fbss.Prepare();
ASSERT_TRUE(result.ok());
size_t fbss_size = fbss.BufferSize();
auto& fbss_stream = fbss.GetStream();
output_file << fbss_stream.rdbuf();
EXPECT_EQ(output_file.tellp(), fbss_size);
output_file.close();
// Read the file back and check contents.
std::ifstream input_file(output_file_path, std::ios::binary);
ASSERT_TRUE(input_file.is_open());
// Read the original file into a buffer
std::ifstream original_file(file_path, std::ios::binary | std::ios::ate);
ASSERT_TRUE(original_file.is_open());
size_t original_size = original_file.tellg();
original_file.seekg(0, std::ios::beg);
std::vector<char> original_buffer(original_size);
original_file.read(original_buffer.data(), original_size);
original_file.close();
// Read the copied file into a buffer
std::vector<char> copied_buffer(fbss_size);
input_file.read(copied_buffer.data(), fbss_size);
input_file.close();
// Compare the buffers
ASSERT_EQ(original_size, fbss_size);
EXPECT_TRUE(std::equal(original_buffer.begin(), original_buffer.end(),
copied_buffer.begin(), copied_buffer.end()));
}
TEST(LiteRTLMSectionTest, TestProtoSectionStream) {
using litert::lm::proto::LlmMetadata;
// Constants for the Token Generation Data
const std::string start_token = "<start>";
const std::vector<std::string> stop_tokens = {"<stop>", "<eos>"};
const auto output_file_path =
std::filesystem::path(::testing::TempDir()) / "llm_metadata.pb";
// Create an LlmMetadata protocol buffer
LlmMetadata metadata;
// Set the start_token
metadata.mutable_start_token()->set_token_str(start_token);
// Set the stop_tokens
for (const std::string& stop_token : stop_tokens) {
metadata.add_stop_tokens()->set_token_str(stop_token);
}
// ** Write the file using typical protobuf serialization **
std::string serialized_params = metadata.SerializeAsString();
// Convert the serialized string to a vector of unsigned chars
std::vector<unsigned char> buffer(serialized_params.begin(),
serialized_params.end());
// Write the buffer to a file
std::ofstream output_file(output_file_path, std::ios::binary);
ASSERT_TRUE(output_file.is_open());
output_file.write(reinterpret_cast<const char*>(buffer.data()),
buffer.size());
std::ofstream::pos_type bytes_written = output_file.tellp();
ASSERT_GT(bytes_written, 0);
output_file.close();
// ** Write the file using SectionStream interface **
ProtoBufSectionStream<LlmMetadata> pbss(metadata);
absl::Status result = pbss.Prepare();
ASSERT_TRUE(result.ok());
size_t pbss_size = pbss.BufferSize();
auto& pbss_stream = pbss.GetStream();
const auto output_file_streamed_path =
std::filesystem::path(::testing::TempDir()) / "llm_metadata_streamed.pb";
std::ofstream output_streamed(output_file_streamed_path, std::ios::binary);
ASSERT_TRUE(output_streamed.is_open());
output_streamed << pbss_stream.rdbuf();
EXPECT_EQ(output_streamed.tellp(), pbss_size);
output_streamed.close();
// ** Read the file back in and check the contents **
std::ifstream input_streamed(output_file_streamed_path, std::ios::binary);
ASSERT_TRUE(input_streamed.is_open());
std::stringstream ss;
ss << input_streamed.rdbuf(); // Read the entire file into a stringstream
std::string serialized_read_back = ss.str();
input_streamed.close();
LlmMetadata params_read_back;
ASSERT_TRUE(params_read_back.ParseFromString(serialized_read_back));
// Compare the fields.
EXPECT_EQ(metadata.start_token().token_str(),
params_read_back.start_token().token_str());
EXPECT_EQ(metadata.stop_tokens().size(),
params_read_back.stop_tokens().size());
for (int i = 0; i < metadata.stop_tokens().size(); ++i) {
EXPECT_EQ(metadata.stop_tokens(i).token_str(),
params_read_back.stop_tokens(i).token_str());
}
}
TEST(LiteRTLMSectionTest, TestStringBackedSectionStream) {
using litert::lm::schema::StringBackedSectionStream;
// Use a string with embedded nulls to ensure it handles binary data
// correctly.
const std::string original_data("Test\0Binary\0Data", 16);
// Use the SectionStream interface.
StringBackedSectionStream stream(original_data);
// Prepare the stream for reading.
absl::Status result = stream.Prepare();
EXPECT_TRUE(result.ok());
// Get the stream's properties.
const size_t stream_size = stream.BufferSize();
std::istream& data_stream = stream.GetStream();
// The reported buffer size should match the original data's size.
EXPECT_EQ(stream_size, original_data.size());
// Read the entire contents of the stream back into a new string.
std::stringstream buffer;
buffer << data_stream.rdbuf();
const std::string data_read_back = buffer.str();
// Compare the data that was read from the stream to the original data.
// They must be identical in both size and content.
EXPECT_EQ(data_read_back.size(), original_data.size());
EXPECT_EQ(data_read_back, original_data);
// Finally, test that the stream can be finalized correctly.
result = stream.Finalize();
EXPECT_TRUE(result.ok());
EXPECT_FALSE(stream.IsReady());
}
} // namespace
} // namespace litert::lm::schema