# Copyright 2025 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import pathlib
import zlib
from absl.testing import absltest
from absl.testing import parameterized
from google.protobuf import text_format
from litert_lm.runtime.proto import llm_metadata_pb2
from litert_lm.schema.py import litertlm_builder
from litert_lm.schema.py import litertlm_core
from litert_lm.schema.py import litertlm_peek


_TOML_TEMPLATE = """
# A template for testing the TOML parser.
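# The {PLACEHOLDER} tokens are substituted with real file paths in
# test_from_toml before parsing.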
[system_metadata]
entries = [
{ key = "author", value_type = "String", value = "The ODML Authors" }
]
[[section]]
# Section 0: LlmMetadataProto
section_type = "LlmMetadata"
data_path = "{LLM_METADATA_PATH}"
[[section]]
# Section 1: SP_Tokenizer
section_type = "SP_Tokenizer"
data_path = "{SP_TOKENIZER_PATH}"
[[section]]
# Section 2: TFLiteModel (Embedder)
section_type = "TFLiteModel"
model_type = "EMBEDDER"
data_path = "{EMBEDDER_PATH}"
[[section]]
# Section 3: TFLiteModel (Prefill/Decode)
section_type = "TFLiteModel"
model_type = "PREFILL_DECODE"
data_path = "{PREFILL_DECODE_PATH}"
additional_metadata = [
{ key = "License", value_type = "String", value = "Example" }
]
"""


class LitertlmBuilderTest(parameterized.TestCase):
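  """Tests for litertlm_builder.LitertLmFileBuilder."""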

  def setUp(self):
super().setUp()
self.temp_dir = self.create_tempdir().full_path

  def _create_dummy_file(self, filename: str, content: bytes) -> str:
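    """Writes `content` to a file in the temp dir and returns its path."""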
filepath = os.path.join(self.temp_dir, filename)
with litertlm_core.open_file(filepath, "wb") as f:
f.write(content)
return filepath

  def _add_system_metadata(self, builder: litertlm_builder.LitertLmFileBuilder):
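    """Adds a single string entry to the builder's system metadata."""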
builder.add_system_metadata(
litertlm_builder.Metadata(
key="sys_test_k",
value="sys_test_v",
dtype=litertlm_builder.DType.STRING,
)
)

  def _build_and_read_litertlm(
self, builder: litertlm_builder.LitertLmFileBuilder
) -> str:
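    """Builds the .litertlm file and returns litertlm_peek's text output."""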
path = os.path.join(self.temp_dir, "litertlm.litertlm")
with litertlm_core.open_file(path, "wb") as f:
builder.build(f)
stream = io.StringIO()
litertlm_peek.peek_litertlm_file(path, self.temp_dir, stream)
return stream.getvalue()

  def test_add_system_metadata(self):
"""Tests that system metadata is added correctly."""
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Key: sys_test_k, Value (String): sys_test_v", ss)
self.assertIn("Sections (0)", ss)

  def test_add_system_metadata_duplicate_key(self):
"""Tests that adding system metadata with a duplicate key raises a ValueError."""
builder = litertlm_builder.LitertLmFileBuilder()
builder.add_system_metadata(
litertlm_builder.Metadata(
key="sys_key1",
value="sys_val1",
dtype=litertlm_builder.DType.STRING,
)
)
with self.assertRaises(ValueError):
builder.add_system_metadata(
litertlm_builder.Metadata(
key="sys_key1",
value="sys_val2",
dtype=litertlm_builder.DType.STRING,
)
)

  def test_add_llm_metadata_binary(self):
"""Tests that LLM metadata can be added from a binary proto file."""
llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
bin_proto = llm_metadata.SerializeToString()
metadata_path = self._create_dummy_file("llm.pb", bin_proto)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_llm_metadata(metadata_path)
ss = self._build_and_read_litertlm(builder)
self.assertIn("max_num_tokens: 123", ss)
self.assertIn("Sections (1)", ss)

  def test_add_llm_metadata_text(self):
"""Tests that LLM metadata can be added from a text proto file."""
llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
text_proto = text_format.MessageToString(llm_metadata)
metadata_path = self._create_dummy_file(
"llm.textproto", text_proto.encode("utf-8")
)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_llm_metadata(metadata_path)
ss = self._build_and_read_litertlm(builder)
self.assertIn("max_num_tokens: 123", ss)
self.assertIn("Sections (1)", ss)

  def test_add_llm_metadata_not_found(self):
"""Tests that adding a non-existent LLM metadata file raises a FileNotFoundError."""
builder = litertlm_builder.LitertLmFileBuilder()
with self.assertRaises(FileNotFoundError):
builder.add_llm_metadata("nonexistent.pb")

  def test_add_llm_metadata_already_added(self):
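    """Tests that adding LLM metadata twice raises an AssertionError."""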
builder = litertlm_builder.LitertLmFileBuilder()
metadata_path = self._create_dummy_file("llm.pb", b"")
builder.add_llm_metadata(metadata_path)
with self.assertRaises(AssertionError):
builder.add_llm_metadata(metadata_path)

  @parameterized.named_parameters(
("prefill_decode", litertlm_builder.TfLiteModelType.PREFILL_DECODE),
("mtp_drafter", litertlm_builder.TfLiteModelType.MTP_DRAFTER),
)
def test_add_tflite_model(self, model_type: litertlm_builder.TfLiteModelType):
"""Tests that a TFLite model can be added correctly."""
tflite_path = self._create_dummy_file(
"model.tflite", b"dummy tflite content"
)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_tflite_model(
tflite_path,
model_type,
additional_metadata=[
litertlm_builder.Metadata(
key="test_key",
value="test_value",
dtype=litertlm_builder.DType.STRING,
)
],
)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: TFLiteModel", ss)
self.assertIn(f"Key: model_type, Value (String): {model_type.value}", ss)
self.assertIn("Key: test_key, Value (String): test_value", ss)

  def test_add_tflite_model_with_backend_constraint(self):
    """Tests that a TFLite model with a backend constraint is added correctly."""
tflite_path = self._create_dummy_file(
"model.tflite", b"dummy tflite content"
)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_tflite_model(
tflite_path,
litertlm_builder.TfLiteModelType.PREFILL_DECODE,
backend_constraint="gpu",
)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: TFLiteModel", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
self.assertIn("Key: backend_constraint, Value (String): gpu", ss)

  def test_add_tflite_model_with_multiple_backend_constraint(self):
    """Tests that multiple backend constraints are normalized correctly."""
tflite_path = self._create_dummy_file(
"model.tflite", b"dummy tflite content"
)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_tflite_model(
tflite_path,
litertlm_builder.TfLiteModelType.PREFILL_DECODE,
backend_constraint="cpu, GPU",
)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: TFLiteModel", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
self.assertIn("Key: backend_constraint, Value (String): cpu, gpu", ss)

  def test_add_tflite_model_with_invalid_backend_constraint(self):
    """Tests that an invalid backend constraint raises a ValueError."""
tflite_path = self._create_dummy_file(
"model.tflite", b"dummy tflite content"
)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
with self.assertRaisesRegex(ValueError, "Invalid backend constraint"):
builder.add_tflite_model(
tflite_path,
litertlm_builder.TfLiteModelType.PREFILL_DECODE,
backend_constraint="foo, bar",
)

  def test_add_tflite_model_override_type(self):
"""Tests that overriding the model type in additional metadata raises a ValueError."""
tflite_path = self._create_dummy_file(
"model.tflite", b"dummy tflite content"
)
additional_metadata = [
litertlm_builder.Metadata(
key="model_type", value="bad", dtype=litertlm_builder.DType.STRING
)
]
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
with self.assertRaises(ValueError):
builder.add_tflite_model(
tflite_path,
litertlm_builder.TfLiteModelType.EMBEDDER,
additional_metadata=additional_metadata,
)

  def test_add_tflite_weights(self):
"""Tests that a TFLite weights file can be added correctly."""
tflite_weights_path = self._create_dummy_file(
"model.weights", b"dummy tflite weights content"
)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_tflite_weights(
tflite_weights_path,
litertlm_builder.TfLiteModelType.PREFILL_DECODE,
additional_metadata=[
litertlm_builder.Metadata(
key="test_key",
value="test_value",
dtype=litertlm_builder.DType.STRING,
)
],
)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: TFLiteWeights", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
self.assertIn("Key: test_key, Value (String): test_value", ss)

  def test_add_sentencepiece_tokenizer(self):
"""Tests that a SentencePiece tokenizer can be added correctly."""
sp_path = self._create_dummy_file("sp.model", b"dummy sp content")
additional_metadata = [
litertlm_builder.Metadata(
key="test_key",
value="test_value",
dtype=litertlm_builder.DType.STRING,
)
]
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_sentencepiece_tokenizer(
sp_path, additional_metadata=additional_metadata
)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: SP_Tokenizer", ss)
self.assertIn("Key: test_key, Value (String): test_value", ss)

  def test_add_hf_tokenizer(self):
"""Tests that a HuggingFace tokenizer can be added correctly."""
hf_content = b'{"version": "1.0"}'
hf_path = self._create_dummy_file("tokenizer.json", hf_content)
additional_metadata = [
litertlm_builder.Metadata(
key="test_key",
value="test_value",
dtype=litertlm_builder.DType.STRING,
)
]
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_hf_tokenizer(hf_path, additional_metadata=additional_metadata)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: HF_Tokenizer_Zlib", ss)
self.assertIn("Key: test_key, Value (String): test_value", ss)
    # Verify content compression: the payload should be an 8-byte
    # little-endian uncompressed size followed by the zlib stream.
with litertlm_core.open_file(
os.path.join(self.temp_dir, "litertlm.litertlm"), "rb"
) as f:
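      # Section payload data starts at offset BLOCK_SIZE, presumably just
      # past the header block.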
f.seek(litertlm_core.BLOCK_SIZE)
# Read uncompressed size (8 bytes)
uncompressed_size = int.from_bytes(f.read(8), "little")
self.assertLen(hf_content, uncompressed_size)
# Read remaining data (compressed)
compressed_data = f.read()
# Decompress and verify. zlib.decompress will stop at end of stream,
# ignoring padding
decompressed = zlib.decompress(compressed_data)
self.assertEqual(decompressed, hf_content)

  def test_add_hf_tokenizer_zlib(self):
    """Tests that a zlib-compressed HuggingFace tokenizer is kept as-is."""
zlib_content = b"dummy zlib content"
hf_path = self._create_dummy_file("tokenizer.zlib", zlib_content)
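    # The ".zlib" extension presumably marks the content as pre-compressed,
    # so the builder is expected to store it verbatim (verified below).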
additional_metadata = [
litertlm_builder.Metadata(
key="test_key",
value="test_value",
dtype=litertlm_builder.DType.STRING,
)
]
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_hf_tokenizer(hf_path, additional_metadata=additional_metadata)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (1)", ss)
self.assertIn("Data Type: HF_Tokenizer_Zlib", ss)
self.assertIn("Key: test_key, Value (String): test_value", ss)
# Verify content is raw (not re-compressed and no size prefix)
with litertlm_core.open_file(
os.path.join(self.temp_dir, "litertlm.litertlm"), "rb"
) as f:
f.seek(litertlm_core.BLOCK_SIZE)
# Should match exact content immediately
read_content = f.read(len(zlib_content))
self.assertEqual(read_content, zlib_content)

  def test_add_tokenizer_already_added(self):
"""Tests that adding a tokenizer more than once raises an AssertionError."""
sp_path = self._create_dummy_file("sp.model", b"")
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_sentencepiece_tokenizer(sp_path)
with self.assertRaises(AssertionError):
builder.add_hf_tokenizer(self._create_dummy_file("tokenizer.json", b""))
with self.assertRaises(AssertionError):
builder.add_sentencepiece_tokenizer(
self._create_dummy_file("tokenizer.json", b"")
)

  def test_end_to_end(self):
"""Tests a more complex end-to-end scenario with multiple sections."""
sp_path = self._create_dummy_file("sp.model", b"dummy sp content")
tflite_path = self._create_dummy_file(
"model.tflite", b"dummy tflite content"
)
llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
bin_proto = llm_metadata.SerializeToString()
metadata_path = self._create_dummy_file("llm.pb", bin_proto)
builder = litertlm_builder.LitertLmFileBuilder()
self._add_system_metadata(builder)
builder.add_sentencepiece_tokenizer(sp_path)
builder.add_tflite_model(
tflite_path, model_type=litertlm_builder.TfLiteModelType.EMBEDDER
)
builder.add_tflite_model(
tflite_path, model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE
)
builder.add_llm_metadata(metadata_path)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (4)", ss)
self.assertIn("Data Type: SP_Tokenizer", ss)
self.assertIn("Data Type: TFLiteModel", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_embedder", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
self.assertIn("Data Type: LlmMetadataProto", ss)
self.assertIn("max_num_tokens: 123", ss)

  @parameterized.named_parameters(
("relative_path", True),
("absolute_path", False),
)
def test_from_toml(self, use_relative_path: bool):
"""Tests that a LitertLmFileBuilder can be initialized from a TOML file."""
sp_filename = "sp.model"
tflite_filename = "model.tflite"
metadata_filename = "llm.pb"
sp_path_abs = self._create_dummy_file(sp_filename, b"dummy sp content")
tflite_path_abs = self._create_dummy_file(
tflite_filename, b"dummy tflite content"
)
metadata_path_abs = self._create_dummy_file(
metadata_filename,
llm_metadata_pb2.LlmMetadata(max_num_tokens=123).SerializeToString(),
)
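    # Exercise both styles: bare filenames (presumably resolved relative to
    # the TOML file's directory) and absolute POSIX paths.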
if use_relative_path:
sp_path = sp_filename
tflite_path = tflite_filename
metadata_path = metadata_filename
else:
sp_path = pathlib.Path(sp_path_abs).as_posix()
tflite_path = pathlib.Path(tflite_path_abs).as_posix()
metadata_path = pathlib.Path(metadata_path_abs).as_posix()
toml_path = self._create_dummy_file(
"test.toml",
_TOML_TEMPLATE.replace("{LLM_METADATA_PATH}", metadata_path)
.replace("{SP_TOKENIZER_PATH}", sp_path)
.replace("{EMBEDDER_PATH}", tflite_path)
.replace("{PREFILL_DECODE_PATH}", tflite_path)
.encode("utf-8"),
)
builder = litertlm_builder.LitertLmFileBuilder.from_toml_file(toml_path)
ss = self._build_and_read_litertlm(builder)
self.assertIn("Sections (4)", ss)
self.assertIn("Data Type: SP_Tokenizer", ss)
self.assertIn("Data Type: TFLiteModel", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_embedder", ss)
self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
self.assertIn("Data Type: LlmMetadataProto", ss)
self.assertIn("max_num_tokens: 123", ss)


if __name__ == "__main__":
absltest.main()