Spaces:
Running
Running
| # Copyright 2025 The ODML Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import io | |
| import os | |
| import pathlib | |
| import subprocess | |
| from absl.testing import absltest | |
| from litert_lm.runtime.proto import llm_metadata_pb2 | |
| from litert_lm.schema.py import litertlm_core | |
| from litert_lm.schema.py import litertlm_peek | |
| from python import runfiles | |
# TOML configuration used by the `toml` subcommand tests. The `{..._PATH}`
# tokens are placeholders: tests substitute them with real file paths via
# `str.replace` before writing the config to disk. The same `{EMBEDDER_PATH}`
# and `{PREFILL_DECODE_PATH}` file may be reused for both TFLiteModel sections.
_TOML_TEMPLATE = """
# A template for testing the TOML parser.
[system_metadata]
entries = [
{ key = "author", value_type = "String", value = "The ODML Authors" }
]
[[section]]
# Section 0: LlmMetadataProto
section_type = "LlmMetadata"
data_path = "{LLM_METADATA_PATH}"
[[section]]
# Section 1: HF_Tokenizer
section_type = "HF_Tokenizer"
data_path = "{HF_TOKENIZER_PATH}"
[[section]]
# Section 2: TFLiteModel (Embedder)
section_type = "TFLiteModel"
model_type = "EMBEDDER"
data_path = "{EMBEDDER_PATH}"
[[section]]
# Section 3: TFLiteModel (Prefill/Decode)
section_type = "TFLiteModel"
model_type = "PREFILL_DECODE"
backend_constraint = "GPU"
data_path = "{PREFILL_DECODE_PATH}"
additional_metadata = [
{ key = "License", value_type = "String", value = "Example" }
]
"""
class LiteRTLMBuilderCLITest(absltest.TestCase):
  """End-to-end tests that drive the `litertlm_builder_cli` binary.

  Each test runs the CLI in a subprocess, has it write its output file into a
  per-test temporary directory, and then checks the text produced by
  `litertlm_peek.peek_litertlm_file` for the expected entries.
  """

  def setUp(self):
    """Creates the per-test temporary directory used for inputs and outputs."""
    super().setUp()
    self.temp_dir = self.create_tempdir().full_path

  def _create_placeholder_file(self, filename: str, content: bytes) -> str:
    """Writes `content` to `filename` in the temp dir and returns its path."""
    filepath = os.path.join(self.temp_dir, filename)
    with litertlm_core.open_file(filepath, "wb") as f:
      f.write(content)
    return filepath

  def _write_toml_config(self) -> str:
    """Writes a TOML config in which every `data_path` names an existing file.

    All four placeholders of `_TOML_TEMPLATE` are substituted, so a CLI run on
    the resulting file cannot fail merely because an input file is missing.
    Paths are converted with `as_posix()` before substitution (presumably so
    that Windows backslashes are not read as TOML escape sequences).

    Returns:
      The path of the written `test.toml` file.
    """
    hf_path = pathlib.Path(
        self._create_placeholder_file("tokenizer.json", b'{"version": "1.0"}')
    ).as_posix()
    tflite_path = pathlib.Path(
        self._create_placeholder_file("model.tflite", b"dummy tflite content")
    ).as_posix()
    llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
    metadata_path = pathlib.Path(
        self._create_placeholder_file("llm.pb", llm_metadata.SerializeToString())
    ).as_posix()
    toml_text = (
        _TOML_TEMPLATE.replace("{LLM_METADATA_PATH}", metadata_path)
        .replace("{HF_TOKENIZER_PATH}", hf_path)
        # The same placeholder model file backs both TFLiteModel sections.
        .replace("{EMBEDDER_PATH}", tflite_path)
        .replace("{PREFILL_DECODE_PATH}", tflite_path)
    )
    return self._create_placeholder_file("test.toml", toml_text.encode("utf-8"))

  def _get_command_path(self) -> str:
    """Returns the path to the command binary.

    The binary is looked up as `litertlm_builder_cli` in the directory of this
    test file, resolved through the runfiles library.
    """
    r = runfiles.Create()
    return r.Rlocation(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "litertlm_builder_cli"
        )
    )

  def _run_command(self, *args: str) -> str:
    """Runs the command with the given arguments.

    The arguments are followed by `output --path <file>` so that the CLI writes
    its result to `litertlm.litertlm` inside the temp dir.

    Args:
      *args: Subcommands and flags placed between the binary and `output`.

    Returns:
      The path of the output file the CLI was asked to write.

    Raises:
      subprocess.CalledProcessError: If the CLI exits with a non-zero status.
        Its captured stdout and stderr are printed before re-raising.
    """
    output_path = os.path.join(self.temp_dir, "litertlm.litertlm")
    command = [
        self._get_command_path(),
        *args,
        "output",
        "--path",
        output_path,
    ]
    try:
      subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
      # Decode leniently: a stray non-UTF-8 byte in the captured output must not
      # replace the CalledProcessError with a UnicodeDecodeError.
      print("command stdout:\n", e.stdout.decode("utf-8", errors="replace"))
      print("command stderr:\n", e.stderr.decode("utf-8", errors="replace"))
      raise
    return output_path

  def _peek_litertlm_file(self, path: str) -> str:
    """Peeks the litertlm file and returns the string representation.

    The temp dir is passed as the second argument of `peek_litertlm_file`
    (presumably the directory for any dumped files -- confirm against
    `litertlm_peek`).
    """
    stream = io.StringIO()
    litertlm_peek.peek_litertlm_file(path, self.temp_dir, stream)
    return stream.getvalue()

  def test_system_metadata(self):
    """Tests that system metadata can be added correctly."""
    args = ["system_metadata", "--str", "key1", "value1"]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Key: key1, Value (String): value1", ss)
    self.assertIn("Sections (0)", ss)

  def test_llm_metadata(self):
    """Tests that LLM metadata can be added from a binary proto file."""
    llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
    bin_proto = llm_metadata.SerializeToString()
    metadata_path = self._create_placeholder_file("llm.pb", bin_proto)
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "llm_metadata",
        "--path",
        metadata_path,
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("max_num_tokens: 123", ss)
    self.assertIn("Sections (1)", ss)

  def test_tflite_model(self):
    """Tests that a TFLite model can be added correctly."""
    tflite_path = self._create_placeholder_file(
        "model.tflite", b"dummy tflite content"
    )
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "tflite_model",
        "--path",
        tflite_path,
        "--model_type",
        "prefill_decode",
        "--str_metadata",
        "model_version",
        "1.0.1",
        "--backend_constraint",
        "CPU",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    self.assertIn("Data Type: TFLiteModel", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: model_version, Value (String): 1.0.1", ss)
    self.assertIn("Key: backend_constraint, Value (String): cpu", ss)

  def test_tflite_weights(self):
    """Tests that TFLite weights can be added correctly via CLI."""
    tflite_path = self._create_placeholder_file(
        "model.weights", b"dummy tflite weights content"
    )
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "tflite_weights",
        "--path",
        tflite_path,
        "--model_type",
        "prefill_decode",
        "--str_metadata",
        "weights_version",
        "1.0.1",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    self.assertIn("Data Type: TFLiteWeights", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: weights_version, Value (String): 1.0.1", ss)

  def test_sp_tokenizer(self):
    """Tests that a SentencePiece tokenizer can be added correctly."""
    sp_path = self._create_placeholder_file("sp.model", b"dummy sp content")
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "sp_tokenizer",
        "--path",
        sp_path,
        "--str_metadata",
        "tokenizer_version",
        "1.0.1",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    self.assertIn("Data Type: SP_Tokenizer", ss)
    self.assertIn("Key: tokenizer_version, Value (String): 1.0.1", ss)

  def test_hf_tokenizer(self):
    """Tests that a HuggingFace tokenizer can be added correctly."""
    hf_path = self._create_placeholder_file(
        "tokenizer.json", b'{"version": "1.0"}'
    )
    args = [
        "system_metadata",
        "--int",
        "my_key",
        "23",
        "hf_tokenizer",
        "--path",
        hf_path,
        "--str_metadata",
        "tokenizer_version",
        "1.0.1",
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (1)", ss)
    # The peek output is expected to report the section as the Zlib variant.
    self.assertIn("Data Type: HF_Tokenizer_Zlib", ss)
    self.assertIn("Key: tokenizer_version, Value (String): 1.0.1", ss)

  def test_end_to_end(self):
    """Tests a more complex end-to-end scenario with multiple sections."""
    sp_path = self._create_placeholder_file("sp.model", b"dummy sp content")
    tflite_path = self._create_placeholder_file(
        "model.tflite", b"dummy tflite content"
    )
    llm_metadata = llm_metadata_pb2.LlmMetadata(max_num_tokens=123)
    bin_proto = llm_metadata.SerializeToString()
    metadata_path = self._create_placeholder_file("llm.pb", bin_proto)
    args = [
        "system_metadata",
        "--str",
        "Authors",
        "ODML team",
        "sp_tokenizer",
        "--path",
        sp_path,
        "tflite_model",
        "--path",
        tflite_path,
        "--model_type",
        "embedder",
        "tflite_model",
        "--path",
        tflite_path,
        "--model_type",
        "prefill_decode",
        "--backend_constraint",
        "GPU",
        "llm_metadata",
        "--path",
        metadata_path,
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (4)", ss)
    self.assertIn("Data Type: SP_Tokenizer", ss)
    self.assertIn("Data Type: TFLiteModel", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_embedder", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: backend_constraint, Value (String): gpu", ss)
    self.assertIn("Data Type: LlmMetadataProto", ss)
    self.assertIn("max_num_tokens: 123", ss)

  def test_toml_file(self):
    """Tests that a TOML file can be added correctly."""
    toml_path = self._write_toml_config()
    args = [
        "toml",
        "--path",
        toml_path,
    ]
    output_path = self._run_command(*args)
    self.assertTrue(os.path.exists(output_path))
    ss = self._peek_litertlm_file(output_path)
    self.assertIn("Sections (4)", ss)
    self.assertIn("Data Type: HF_Tokenizer_Zlib", ss)
    self.assertIn("Data Type: TFLiteModel", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_embedder", ss)
    self.assertIn("Key: model_type, Value (String): tf_lite_prefill_decode", ss)
    self.assertIn("Key: backend_constraint, Value (String): gpu", ss)
    self.assertIn("Data Type: LlmMetadataProto", ss)
    self.assertIn("max_num_tokens: 123", ss)

  def test_toml_cannot_be_used_with_other_args(self):
    """Tests that a TOML file cannot be used with other args."""
    # Use a fully valid TOML config (the same one `test_toml_file` builds) so
    # that the only thing wrong with the invocation is the extra
    # `system_metadata` subcommand. With unresolved `{..._PATH}` placeholders
    # the CLI could fail on a missing input file instead, and the test would
    # pass for the wrong reason.
    toml_path = self._write_toml_config()
    args = [
        "toml",
        "--path",
        toml_path,
        "system_metadata",
        "--str",
        "key1",
        "value1",
    ]
    with self.assertRaises(subprocess.CalledProcessError):
      self._run_command(*args)

  def test_help_root(self):
    """Tests that the help command prints the correct output."""
    command = [self._get_command_path(), "--help"]
    output = subprocess.run(command, check=True, capture_output=True)
    self.assertEqual(output.returncode, 0)
    stdout = output.stdout.decode("utf-8")
    self.assertIn(
        "Build a LiteRT-LM file from input files and metadata", stdout
    )

  def test_help_subcommand(self):
    """Tests that the help command prints the correct output for subcommand."""
    command = [self._get_command_path(), "system_metadata", "--help"]
    output = subprocess.run(command, check=True, capture_output=True)
    self.assertEqual(output.returncode, 0)
    stdout = output.stdout.decode("utf-8")
    self.assertIn(
        "Add one or more system metadata key-value pairs to the LiteRT-LM"
        " file.",
        stdout,
    )
# Run the test cases through absltest when this file is executed directly.
if __name__ == "__main__":
  absltest.main()