# litert_lm/schema/py/litertlm_builder_cli_test.py
# Copyright 2025 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import pathlib
import subprocess
from absl.testing import absltest
from litert_lm.runtime.proto import llm_metadata_pb2
from litert_lm.schema.py import litertlm_core
from litert_lm.schema.py import litertlm_peek
from python import runfiles
# Template TOML consumed by the `toml` subcommand tests. The {UPPER_SNAKE}
# placeholders are substituted via str.replace() rather than str.format(),
# because the inline-table braces (e.g. `{ key = ... }`) would be
# misinterpreted as format fields.
_TOML_TEMPLATE = """
# A template for testing the TOML parser.
[system_metadata]
entries = [
{ key = "author", value_type = "String", value = "The ODML Authors" }
]
[[section]]
# Section 0: LlmMetadataProto
section_type = "LlmMetadata"
data_path = "{LLM_METADATA_PATH}"
[[section]]
# Section 1: HF_Tokenizer
section_type = "HF_Tokenizer"
data_path = "{HF_TOKENIZER_PATH}"
[[section]]
# Section 2: TFLiteModel (Embedder)
section_type = "TFLiteModel"
model_type = "EMBEDDER"
data_path = "{EMBEDDER_PATH}"
[[section]]
# Section 3: TFLiteModel (Prefill/Decode)
section_type = "TFLiteModel"
model_type = "PREFILL_DECODE"
backend_constraint = "GPU"
data_path = "{PREFILL_DECODE_PATH}"
additional_metadata = [
{ key = "License", value_type = "String", value = "Example" }
]
"""
class LiteRTLMBuilderCLITest(absltest.TestCase):
  """End-to-end tests for the litertlm_builder CLI binary.

  Each test shells out to the builder binary (located via runfiles) to
  produce a .litertlm file in a per-test temp directory, then inspects the
  result with litertlm_peek and asserts on its textual summary.
  """

  def setUp(self):
    super().setUp()
    # Per-test scratch directory; absltest removes it on teardown.
    self.temp_dir = self.create_tempdir().full_path

  def _create_placeholder_file(self, filename: str, content: bytes) -> str:
    """Writes `content` to `filename` under the temp dir; returns its path."""
    filepath = os.path.join(self.temp_dir, filename)
    with litertlm_core.open_file(filepath, "wb") as f:
      f.write(content)
    return filepath

  def _get_command_path(self) -> str:
    """Returns the path to the command binary."""
    r = runfiles.Create()
    return r.Rlocation(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "litertlm_builder_cli"
        )
    )

  def _run_command(self, *args) -> str:
    """Runs the builder CLI with `args` plus a trailing output clause.

    Args:
      *args: Subcommands and flags to pass before the appended
        "output --path <file>" arguments.

    Returns:
      The path the CLI was asked to write the .litertlm file to.

    Raises:
      subprocess.CalledProcessError: If the CLI exits non-zero. The
        subprocess stdout/stderr are printed first to aid debugging.
    """
    output_path = os.path.join(self.temp_dir, "litertlm.litertlm")
    command = [
        self._get_command_path(),
        *args,
        "output",
        "--path",
        output_path,
    ]
    try:
      subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
      # Surface the CLI's own output before failing the test.
      print("command stdout:\n", e.stdout.decode("utf-8"))
      print("command stderr:\n", e.stderr.decode("utf-8"))
      # Bare `raise` re-raises with the original traceback intact
      # (unlike `raise e`, which restarts the traceback here).
      raise
    return output_path

  def _peek_litertlm_file(self, path: str) -> str:
    """Peeks the litertlm file and returns the string representation."""
    stream = io.StringIO()
    litertlm_peek.peek_litertlm_file(path, self.temp_dir, stream)
    return stream.getvalue()

  def test_system_metadata(self):
    """Tests that system metadata can be added correctly."""
    args = ["system_metadata", "--str", "key1", "value1"]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Key: key1, Value (String): value1", ss)
    self.assertIn("Sections (0)", ss)

  def test_llm_metadata(self):
    """Tests that LLM metadata can be added from a binary proto file."""
    llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
    bin_proto = llm_metadata.SerializeToString()
    metadata_path = self._create_placeholder_file("llm.pb", bin_proto)
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "llm_metadata",
        "--path",
        metadata_path,
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("max_num_tokens: 123", ss)
    self.assertIn("Sections (1)", ss)

  def test_tflite_model(self):
    """Tests that a TFLite model can be added correctly."""
    tflite_path = self._create_placeholder_file(
        "model.tflite", b"dummy tflite content"
    )
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "tflite_model",
        "--path",
        tflite_path,
        "--model_type",
        "prefill_decode",
        "--str_metadata",
        "model_version",
        "1.0.1",
        "--backend_constraint",
        "CPU",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    self.assertIn("Data Type: TFLiteModel", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: model_version, Value (String): 1.0.1", ss)
    self.assertIn("Key: backend_constraint, Value (String): cpu", ss)

  def test_tflite_weights(self):
    """Tests that TFLite weights can be added correctly via CLI."""
    tflite_path = self._create_placeholder_file(
        "model.weights", b"dummy tflite weights content"
    )
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "tflite_weights",
        "--path",
        tflite_path,
        "--model_type",
        "prefill_decode",
        "--str_metadata",
        "weights_version",
        "1.0.1",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    self.assertIn("Data Type: TFLiteWeights", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: weights_version, Value (String): 1.0.1", ss)

  def test_sp_tokenizer(self):
    """Tests that a SentencePiece tokenizer can be added correctly."""
    sp_path = self._create_placeholder_file("sp.model", b"dummy sp content")
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "sp_tokenizer",
        "--path",
        sp_path,
        "--str_metadata",
        "tokenizer_version",
        "1.0.1",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    self.assertIn("Data Type: SP_Tokenizer", ss)
    self.assertIn("Key: tokenizer_version, Value (String): 1.0.1", ss)

  def test_hf_tokenizer(self):
    """Tests that a HuggingFace tokenizer can be added correctly."""
    hf_path = self._create_placeholder_file(
        "tokenizer.json", b'{"version": "1.0"}'
    )
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "hf_tokenizer",
        "--path",
        hf_path,
        "--str_metadata",
        "tokenizer_version",
        "1.0.1",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    # HF tokenizers are stored zlib-compressed by the builder.
    self.assertIn("Data Type: HF_Tokenizer_Zlib", ss)
    self.assertIn("Key: tokenizer_version, Value (String): 1.0.1", ss)

  def test_end_to_end(self):
    """Tests a more complex end-to-end scenario with multiple sections."""
    sp_path = self._create_placeholder_file("sp.model", b"dummy sp content")
    tflite_path = self._create_placeholder_file(
        "model.tflite", b"dummy tflite content"
    )
    llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
    bin_proto = llm_metadata.SerializeToString()
    metadata_path = self._create_placeholder_file("llm.pb", bin_proto)
    args = [
        "system_metadata",
        "--str",
        "Authors",
        "ODML team",
        "sp_tokenizer",
        "--path",
        sp_path,
        "tflite_model",
        "--path",
        tflite_path,
        "--model_type",
        "embedder",
        "tflite_model",
        "--path",
        tflite_path,
        "--model_type",
        "prefill_decode",
        "--backend_constraint",
        "GPU",
        "llm_metadata",
        "--path",
        metadata_path,
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (4)", ss)
    self.assertIn("Data Type: SP_Tokenizer", ss)
    self.assertIn("Data Type: TFLiteModel", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_embedder", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: backend_constraint, Value (String): gpu", ss)
    self.assertIn("Data Type: LlmMetadataProto", ss)
    self.assertIn("max_num_tokens: 123", ss)

  def test_toml_file(self):
    """Tests that a TOML file can be added correctly."""
    # as_posix() keeps the paths forward-slashed so they are valid TOML
    # strings on Windows as well.
    hf_path = pathlib.Path(
        self._create_placeholder_file("tokenizer.json", b'{"version": "1.0"}')
    ).as_posix()
    tflite_path = pathlib.Path(
        self._create_placeholder_file("model.tflite", b"dummy tflite content")
    ).as_posix()
    llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
    bin_proto = llm_metadata.SerializeToString()
    metadata_path = pathlib.Path(
        self._create_placeholder_file("llm.pb", bin_proto)
    ).as_posix()
    toml_path = self._create_placeholder_file(
        "test.toml",
        _TOML_TEMPLATE.replace("{LLM_METADATA_PATH}", metadata_path)
        .replace("{HF_TOKENIZER_PATH}", hf_path)
        .replace("{EMBEDDER_PATH}", tflite_path)
        .replace("{PREFILL_DECODE_PATH}", tflite_path)
        .encode("utf-8"),
    )
    args = [
        "toml",
        "--path",
        toml_path,
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (4)", ss)
    self.assertIn("Data Type: HF_Tokenizer_Zlib", ss)
    self.assertIn("Data Type: TFLiteModel", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_embedder", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: backend_constraint, Value (String): gpu", ss)
    self.assertIn("Data Type: LlmMetadataProto", ss)
    self.assertIn("max_num_tokens: 123", ss)

  def test_toml_cannot_be_used_with_other_args(self):
    """Tests that a TOML file cannot be used with other args."""
    tflite_path = self._create_placeholder_file(
        "model.tflite", b"dummy tflite content"
    )
    # Only one placeholder is substituted: the command is expected to fail
    # before the TOML contents matter.
    toml_path = self._create_placeholder_file(
        "test.toml",
        _TOML_TEMPLATE.replace("{PREFILL_DECODE_PATH}", tflite_path).encode(
            "utf-8"
        ),
    )
    args = [
        "toml",
        "--path",
        toml_path,
        "system_metadata",
        "--str",
        "key1",
        "value1",
    ]
    with self.assertRaises(subprocess.CalledProcessError):
      self._run_command(*args)

  def test_help_root(self):
    """Tests that the help command prints the correct output."""
    command = [self._get_command_path(), "--help"]
    output = subprocess.run(command, check=True, capture_output=True)
    self.assertEqual(output.returncode, 0)
    stdout = output.stdout.decode("utf-8")
    self.assertIn(
        "Build a LiteRT-LM file from input files and metadata", stdout
    )

  def test_help_subcommand(self):
    """Tests that the help command prints the correct output for subcommand."""
    command = [self._get_command_path(), "system_metadata", "--help"]
    output = subprocess.run(command, check=True, capture_output=True)
    self.assertEqual(output.returncode, 0)
    stdout = output.stdout.decode("utf-8")
    self.assertIn(
        "Add one or more system metadata key-value pairs to the LiteRT-LM"
        " file.",
        stdout,
    )
# Run the absltest main loop when executed directly.
if __name__ == "__main__":
  absltest.main()