File size: 15,493 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
# Copyright 2025 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""CLI tool for building LiteRT-LM files.

There are two ways to use this tool:

1. Building the file by specifying the components as CLI arguments:

```
bazel run //schema/py:litertlm_builder_cli -- \
  system_metadata --str Authors "ODML team" \
  llm_metadata --path llm.pb \
  tflite_model --path embedder.tflite --model_type embedder  --str_metadata model_version "1.0.1" \
  tflite_model --path model.tflite --model_type prefill_decode \
  sp_tokenizer --path sp.model \
  output --path output.litertlm
```
Notes:
- Constraints from litertlm_builder.py still apply.
- The order of the components in the CLI arguments determines the order of the
  sections in the output LiteRT-LM file.
- Each section can carry multiple metadata entries (repeat `--str_metadata`).

2. Building the file by specifying the components as a TOML file:

TOML file example:
```
[system_metadata]
entries = [
  { key = "author", value_type = "String", value = "The ODML Authors" }
]

[[section]]
# Section 0: LlmMetadataProto. Can be a text or binary proto file.
section_type = "LlmMetadata"
data_path = "PATH/TO/LLM_METADATA.pb"

[[section]]
# Section 1: SP_Tokenizer (you can also use HF_Tokenizer)
section_type = "SP_Tokenizer"
data_path = "PATH/TO/SP_TOKENIZER.model"

[[section]]
# Section 2: TFLiteModel (Embedder)
section_type = "TFLiteModel"
model_type = "EMBEDDER"
data_path = "PATH/TO/EMBEDDER.tflite"

[[section]]
# Section 3: TFLiteModel (Prefill/Decode)
section_type = "TFLiteModel"
model_type = "PREFILL_DECODE"
data_path = "PATH/TO/PREFILL_DECODE.tflite"
additional_metadata = [
  { key = "License", value_type = "String", value = "Example" },
  { key = "model_version", value_type = "String", value = "1.0.1" }
]
```

```
bazel run //schema/py:litertlm_builder_cli -- \
  toml --path example.toml output --path output.litertlm
```

"""

import argparse
import os
import sys
from typing import BinaryIO, cast

from absl import app

from litert_lm.schema.py import litertlm_builder
from litert_lm.schema.py import litertlm_core

# Names of every subcommand the CLI understands. `_parse_args` treats any
# command-line token equal to one of these names as the start of a new
# subcommand group. These are the same names the sub-parsers are registered
# under in the `_add_*_parser` helpers below.
_SUBCOMMANDS = (
    "toml",
    "system_metadata",
    "llm_metadata",
    "tflite_model",
    "tflite_weights",
    "sp_tokenizer",
    "hf_tokenizer",
    "output",
)


def _add_toml_parser(subparsers) -> None:
  """Registers the `toml` subcommand and its required `--path` flag.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  parser_options = {
      "description": "Add a TOML file to the LiteRT-LM file.",
      "help": "Add a TOML file.",
  }
  toml_command = subparsers.add_parser("toml", **parser_options)
  toml_command.add_argument(
      "--path", type=str, required=True, help="The path to the TOML file."
  )


def _add_system_metadata_parser(subparsers) -> None:
  """Registers the `system_metadata` subcommand.

  The subcommand accepts two optional, repeatable flags, `--str` and `--int`.
  Each takes a KEY and a VALUE; repeated uses are appended to a list of
  [KEY, VALUE] pairs. Values are left as strings at parse time.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  sys_meta_command = subparsers.add_parser(
      "system_metadata",
      description=(
          "Add one or more system metadata key-value pairs to the LiteRT-LM"
          " file."
      ),
      help="Add system metadata.",
  )
  # Both flags share every option except their name and help text.
  flag_help_pairs = (
      (
          "--str",
          "A string key-value pair for the system metadata. Can be specified"
          " multiple times.",
      ),
      (
          "--int",
          "An integer key-value pair for the system metadata. Can be specified"
          " multiple times.",
      ),
  )
  for flag_name, help_text in flag_help_pairs:
    sys_meta_command.add_argument(
        flag_name,
        nargs=2,
        action="append",
        metavar=("KEY", "VALUE"),
        required=False,
        help=help_text,
    )


def _add_metadata_arguments(parser) -> None:
  """Adds the optional, repeatable `--str_metadata KEY VALUE` flag.

  Repeated uses are appended, so the parsed attribute is a list of
  [KEY, VALUE] pairs, or None when the flag is never given.

  Args:
    parser: The argument parser that should accept the flag.
  """
  help_text = (
      "A string key-value pair for the metadata. Can be specified"
      " multiple times."
  )
  parser.add_argument(
      "--str_metadata",
      action="append",
      nargs=2,
      metavar=("KEY", "VALUE"),
      required=False,
      help=help_text,
  )


def _add_llm_metadata_parser(subparsers) -> None:
  """Registers the `llm_metadata` subcommand and its required `--path` flag.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  description = (
      "Add llm metadata to the LiteRT-LM file. Can be a text or binary"
      " proto file."
  )
  llm_command = subparsers.add_parser(
      "llm_metadata", description=description, help="Add llm metadata."
  )
  llm_command.add_argument(
      "--path",
      type=str,
      required=True,
      help="The path to the llm metadata file.",
  )


def _add_tflite_model_parser(subparsers) -> None:
  """Registers the `tflite_model` subcommand.

  Flags:
    --path: Required path to the tflite model file.
    --model_type: Required. Allowed values are the `TfLiteModelType` enum
      values, lowercased and with the "tf_lite_" substring removed.
    --backend_constraint: Optional; the raw value is lowercased before being
      checked against the members of `litertlm_builder.Backend`.
    --str_metadata: Optional, repeatable KEY VALUE pairs.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  model_command = subparsers.add_parser(
      "tflite_model",
      description="Add a tflite model to the LiteRT-LM file.",
      help="Add a tflite model.",
  )
  model_command.add_argument(
      "--path",
      type=str,
      required=True,
      help="The path to the tflite model file.",
  )

  model_type_choices = [
      str(member.value).lower().replace("tf_lite_", "")
      for member in litertlm_builder.TfLiteModelType
  ]
  model_command.add_argument(
      "--model_type",
      type=str,
      required=True,
      choices=model_type_choices,
      help="The type of the tflite model.",
  )

  # NOTE(review): the lowercased string is matched against `Backend` members,
  # which presumably compare equal to lowercase strings -- confirm.
  backend_choices = list(litertlm_builder.Backend)
  model_command.add_argument(
      "--backend_constraint",
      type=str.lower,
      required=False,
      default=None,
      choices=backend_choices,
      help="A list of backend constraints for the tflite model.",
  )
  _add_metadata_arguments(model_command)


def _add_tflite_weights_parser(subparsers) -> None:
  """Registers the `tflite_weights` subcommand.

  Flags:
    --path: Required path to the tflite weights file.
    --model_type: Required. Allowed values are the `TfLiteModelType` enum
      values, lowercased and with the "tf_lite_" substring removed.
    --str_metadata: Optional, repeatable KEY VALUE pairs.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  weights_command = subparsers.add_parser(
      "tflite_weights",
      description="Add tflite weights to the LiteRT-LM file.",
      help="Add tflite weights.",
  )
  weights_command.add_argument(
      "--path",
      type=str,
      required=True,
      help="The path to the tflite weights file.",
  )

  model_type_choices = [
      str(member.value).lower().replace("tf_lite_", "")
      for member in litertlm_builder.TfLiteModelType
  ]
  weights_command.add_argument(
      "--model_type",
      type=str,
      required=True,
      choices=model_type_choices,
      help="The type of the tflite model these weights correspond to.",
  )
  _add_metadata_arguments(weights_command)


def _add_sentencepiece_tokenizer_parser(subparsers) -> None:
  """Registers the `sp_tokenizer` subcommand.

  The subcommand takes a required `--path` flag plus the shared optional
  `--str_metadata` flag.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  parser_options = {
      "description": "Add a sentencepiece tokenizer to the LiteRT-LM file.",
      "help": "Add a sentencepiece tokenizer.",
  }
  sp_command = subparsers.add_parser("sp_tokenizer", **parser_options)
  sp_command.add_argument(
      "--path",
      type=str,
      required=True,
      help="The path to the sentencepiece tokenizer file.",
  )
  _add_metadata_arguments(sp_command)


def _add_hf_tokenizer_parser(subparsers) -> None:
  """Registers the `hf_tokenizer` subcommand.

  The subcommand takes a required `--path` flag plus the shared optional
  `--str_metadata` flag.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  parser_options = {
      "description": "Add a huggingface tokenizer to the LiteRT-LM file.",
      "help": "Add a huggingface tokenizer.",
  }
  hf_command = subparsers.add_parser("hf_tokenizer", **parser_options)
  hf_command.add_argument(
      "--path",
      type=str,
      required=True,
      help="The path to the huggingface tokenizer `tokenizer.json` file.",
  )
  _add_metadata_arguments(hf_command)


def _add_output_path_parser(subparsers) -> None:
  """Registers the `output` subcommand and its required `--path` flag.

  Args:
    subparsers: The sub-parsers action to register the subcommand on.
  """
  # The same sentence is used for the description and both help strings.
  output_text = "The path to the output LiteRT-LM file."
  output_command = subparsers.add_parser(
      "output", description=output_text, help=output_text
  )
  output_command.add_argument(
      "--path", type=str, required=True, help=output_text
  )


def _build_parser() -> argparse.ArgumentParser:
  """Creates the top-level parser with every subcommand registered.

  The chosen subcommand name is stored in the `command` attribute of the
  parsed namespace, and a subcommand is required.

  Returns:
    The fully configured argument parser.
  """
  root_parser = argparse.ArgumentParser(
      description="Build a LiteRT-LM file from input files and metadata."
  )
  commands = root_parser.add_subparsers(dest="command", required=True)
  # Registration order is kept stable; it determines the order in help output.
  registrars = (
      _add_toml_parser,
      _add_system_metadata_parser,
      _add_llm_metadata_parser,
      _add_tflite_model_parser,
      _add_tflite_weights_parser,
      _add_sentencepiece_tokenizer_parser,
      _add_hf_tokenizer_parser,
      _add_output_path_parser,
  )
  for register in registrars:
    register(commands)

  return root_parser


def _parse_args(parser: argparse.ArgumentParser) -> list[argparse.Namespace]:
  """Parses `sys.argv` into one namespace per subcommand invocation.

  The command line is a sequence of subcommand groups, for example
  `llm_metadata --path a.pb output --path out.litertlm`. The arguments are
  split at every token that names a subcommand and each group is parsed on its
  own, so flags shared by several subcommands (such as `--path`) do not clash.

  NOTE: any token equal to a subcommand name starts a new group, so a flag
  value that is literally a subcommand name (e.g. `--path output`) cannot be
  expressed.

  Args:
    parser: The argument parser to use.

  Returns:
    A list of parsed argument namespaces in command-line order. The list is
    empty when no arguments were given or when only `--help`/`-h` was given;
    in both cases the help text is printed.

  Raises:
    ValueError: If an argument appears before any subcommand, or if a
      subcommand group contains arguments that could not be parsed.
  """
  args = sys.argv[1:]
  # With nothing to do, show the help instead of silently doing nothing.
  if not args or (len(args) == 1 and args[0] in ("--help", "-h")):
    print(parser.format_help())
    return []

  # We need to break the arguments into subcommands to ensure overlapping flags
  # are handled correctly. For example, "--path" is a flag for both
  # "llm_metadata" and "output".
  groups: list[list[str]] = []
  for arg in args:
    if arg in _SUBCOMMANDS:
      groups.append([arg])
    elif groups:
      groups[-1].append(arg)
    else:
      # Raise explicitly: `assert` is stripped when Python runs with -O.
      raise ValueError(
          f"No subcommand found for argument: {arg}. Use --help for a list of"
          " subcommands."
      )

  parsed_args = []
  for group in groups:
    parsed, unparsed = parser.parse_known_args(args=group)
    if unparsed:
      raise ValueError(
          f"Failed to parse all arguments. Unparsed args: {unparsed}"
      )
    parsed_args.append(parsed)
  return parsed_args


def _build_system_metadata(
    args: argparse.Namespace,
    builder: litertlm_builder.LitertLmFileBuilder,
) -> None:
  """Adds the `--str` and `--int` system metadata entries to the builder.

  String entries are added first, then integer entries. Integer values are
  converted with `int()`, which raises ValueError for non-numeric input.

  Args:
    args: Parsed namespace of a `system_metadata` subcommand.
    builder: The builder that receives the metadata entries.
  """
  # Each flag is None when it was never given on the command line.
  for key, value in args.str or ():
    string_entry = litertlm_builder.Metadata(
        key=key,
        value=value,
        dtype=litertlm_builder.DType.STRING,
    )
    builder.add_system_metadata(string_entry)

  for key, value in args.int or ():
    int_entry = litertlm_builder.Metadata(
        key=key,
        value=int(value),
        dtype=litertlm_builder.DType.INT32,
    )
    builder.add_system_metadata(int_entry)


def _get_metadata_from_args(
    args: argparse.Namespace,
) -> list[litertlm_builder.Metadata] | None:
  """Converts the `--str_metadata` pairs of a namespace into Metadata objects.

  Args:
    args: Parsed namespace; it may or may not define `str_metadata`.

  Returns:
    A list of string-typed Metadata entries, or None when the namespace has no
    `str_metadata` attribute or no pairs were given.
  """
  pairs = getattr(args, "str_metadata", None)
  if not pairs:
    return None
  entries = [
      litertlm_builder.Metadata(
          key=key,
          value=value,
          dtype=litertlm_builder.DType.STRING,
      )
      for key, value in pairs
  ]
  return entries or None


def _build_llm_metadata(
    args: argparse.Namespace,
    builder: litertlm_builder.LitertLmFileBuilder,
) -> None:
  """Adds the llm metadata file at `args.path` to the builder.

  Args:
    args: Parsed namespace of an `llm_metadata` subcommand.
    builder: The builder that receives the section.
  """
  extra_metadata = _get_metadata_from_args(args)
  builder.add_llm_metadata(
      args.path,
      additional_metadata=extra_metadata,
  )


def _build_tflite_model(
    args: argparse.Namespace,
    builder: litertlm_builder.LitertLmFileBuilder,
) -> None:
  """Adds the tflite model at `args.path` to the builder.

  The CLI model type string is mapped back to a `TfLiteModelType` member via
  `get_enum_from_tf_free_value`.

  Args:
    args: Parsed namespace of a `tflite_model` subcommand.
    builder: The builder that receives the section.
  """
  extra_metadata = _get_metadata_from_args(args)
  model_type = litertlm_builder.TfLiteModelType.get_enum_from_tf_free_value(
      args.model_type
  )
  builder.add_tflite_model(
      args.path,
      model_type,
      backend_constraint=args.backend_constraint,
      additional_metadata=extra_metadata,
  )


def _build_tflite_weights(
    args: argparse.Namespace,
    builder: litertlm_builder.LitertLmFileBuilder,
) -> None:
  """Adds the tflite weights at `args.path` to the builder.

  The CLI model type string is mapped back to a `TfLiteModelType` member via
  `get_enum_from_tf_free_value`.

  Args:
    args: Parsed namespace of a `tflite_weights` subcommand.
    builder: The builder that receives the section.
  """
  extra_metadata = _get_metadata_from_args(args)
  model_type = litertlm_builder.TfLiteModelType.get_enum_from_tf_free_value(
      args.model_type
  )
  builder.add_tflite_weights(
      args.path,
      model_type,
      additional_metadata=extra_metadata,
  )


def _build_sp_tokenizer(
    args: argparse.Namespace,
    builder: litertlm_builder.LitertLmFileBuilder,
) -> None:
  """Adds the sentencepiece tokenizer at `args.path` to the builder.

  Args:
    args: Parsed namespace of an `sp_tokenizer` subcommand.
    builder: The builder that receives the section.
  """
  extra_metadata = _get_metadata_from_args(args)
  builder.add_sentencepiece_tokenizer(
      args.path,
      additional_metadata=extra_metadata,
  )


def _build_hf_tokenizer(
    args: argparse.Namespace,
    builder: litertlm_builder.LitertLmFileBuilder,
) -> None:
  """Adds the huggingface tokenizer at `args.path` to the builder.

  Args:
    args: Parsed namespace of an `hf_tokenizer` subcommand.
    builder: The builder that receives the section.
  """
  extra_metadata = _get_metadata_from_args(args)
  builder.add_hf_tokenizer(
      args.path,
      additional_metadata=extra_metadata,
  )


def _build_litertlm_file(parsed_args: list[argparse.Namespace]) -> None:
  """Builds a LiteRT-LM file from the parsed arguments."""
  if "toml" in [pa.command for pa in parsed_args]:
    toml_path = None
    output_path = None
    for parsed_arg in parsed_args:
      match parsed_arg.command:
        case "output":
          output_path = parsed_arg.path
        case "toml":
          toml_path = parsed_arg.path
        case _:
          raise ValueError(
              "When using TOML, only output and toml are supported."
          )
    assert output_path, "Output path is required."
    assert toml_path, "TOML path is required."
    output_dir = os.path.dirname(output_path)
    if output_dir:
      os.makedirs(output_dir, exist_ok=True)
    with litertlm_core.open_file(output_path, "wb") as f:
      builder = litertlm_builder.LitertLmFileBuilder.from_toml_file(toml_path)
      builder.build(f)
  else:
    builder = litertlm_builder.LitertLmFileBuilder()
    output_path = None
    for parsed_arg in parsed_args:
      match parsed_arg.command:
        case "system_metadata":
          _build_system_metadata(parsed_arg, builder)
        case "llm_metadata":
          _build_llm_metadata(parsed_arg, builder)
        case "tflite_model":
          _build_tflite_model(parsed_arg, builder)
        case "tflite_weights":
          _build_tflite_weights(parsed_arg, builder)
        case "sp_tokenizer":
          _build_sp_tokenizer(parsed_arg, builder)
        case "hf_tokenizer":
          _build_hf_tokenizer(parsed_arg, builder)
        case "output":
          output_path = parsed_arg.path
        case _:
          raise ValueError(f"Unknown subcommand: {parsed_arg.command}")

    assert output_path, "Output path is required."
    output_dir = os.path.dirname(output_path)
    if output_dir:
      os.makedirs(output_dir, exist_ok=True)
    with litertlm_core.open_file(output_path, "wb") as f:
      builder.build(cast(BinaryIO, f))

  print(f"LiteRT-LM file successfully created at {output_path}")


def main(_) -> None:
  """Parses the command line and, if there is anything to do, builds the file.

  Args:
    _: Positional arguments passed by `app.run`; unused because the command
      line is read from `sys.argv` by `_parse_args`.
  """
  parsed_args = _parse_args(_build_parser())
  # An empty result means `_parse_args` found nothing to build.
  if parsed_args:
    _build_litertlm_file(parsed_args)


def run():
  """Entry point for console_scripts.

  Only the program name is handed to `app.run`; the remaining command-line
  arguments are read from `sys.argv` by `_parse_args`, presumably so that
  absl's own flag parsing does not see the subcommand arguments.
  """
  program_name_only = sys.argv[:1]
  app.run(main, program_name_only)

if __name__ == "__main__":
  # Delegate to the console_scripts entry point, which makes the same
  # `app.run(main, sys.argv[:1])` call.
  run()