"""
Bundle a Strategy-1 KV-cache TFLite + SentencePiece tokenizer into a
.litertlm file compatible with Google AI Edge / LiteRT-LM runtime.

Embeds:
- LlmMetadata proto: Gemma3 model type, 2K max tokens, TranslateGemma
  Jinja chat template, BOS/EOS/end_of_turn stop tokens
- TFLite model (model_type=prefill_decode)
- SentencePiece tokenizer

Usage:
    python bundle_litertlm.py \
        --tflite /path/to/model.tflite \
        --tokenizer /path/to/tokenizer.model \
        --output /path/to/output.litertlm \
        [--max-tokens 2048]
"""
|
|
| import argparse |
| import sys |
| import tempfile |
| from pathlib import Path |
|
|
| |
# Make the vendored litert_lm_builder package importable. NOTE(review):
# assumes the package has been unpacked to this fixed path beforehand --
# consider an env var or CLI flag instead of a hard-coded /tmp location.
sys.path.insert(0, "/tmp/litert-lm-pkg")
|
|
| from litert_lm_builder import litertlm_builder |
| from litert_lm.runtime.proto import ( |
| llm_metadata_pb2, |
| llm_model_type_pb2, |
| token_pb2, |
| ) |
|
|
|
|
| |
| |
| |
| |
|
|
# Jinja chat template for TranslateGemma-style translation prompts.
#
# A user message carrying <src>/<dst>/<text> tags is unpacked into an
# explicit "Translate X to Y" instruction; any other user message is passed
# through verbatim inside Gemma turn markers.
#
# FIX: standard Jinja2 defines no `split` *filter* -- `content | split(...)`
# fails at template compile time with "No filter named 'split'". Python's
# str.split() *method* IS callable on template strings, so the tag
# extraction below uses method calls plus indexing, which is equivalent to
# the intended `| split(tag) | last` / `| split(tag) | first` chain.
GENERIC_TRANSLATE_TEMPLATE = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{% set content = message['content'] | trim %}"
    "{% if '<src>' in content and '<dst>' in content and '<text>' in content %}"
    "{% set src_part = content.split('<src>')[-1].split('</src>')[0] | trim %}"
    "{% set dst_part = content.split('<dst>')[-1].split('</dst>')[0] | trim %}"
    "{% set text_part = content.split('<text>')[-1].split('</text>')[0] | trim %}"
    "<start_of_turn>user\n"
    "Translate {{ src_part }} to {{ dst_part }}.\n"
    "Produce only the translation, without explanations:\n\n\n"
    "{{ text_part }}\n"
    "<end_of_turn>\n"
    "{% else %}"
    "<start_of_turn>user\n"
    "{{ content }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% elif message['role'] == 'assistant' %}"
    "<start_of_turn>model\n"
    "{{ message['content'] | trim }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<start_of_turn>model\n"
    "{% endif %}"
)

# TranslateGemma currently uses the generic translation template unchanged.
TRANSLATE_GEMMA_JINJA_TEMPLATE = GENERIC_TRANSLATE_TEMPLATE
|
|
| |
# Jinja chat template for Qwen3 / DictaLM (<|im_start|>/<|im_end|> framing).
# The generation prompt pre-fills an empty <think></think> block --
# NOTE(review): this looks like Qwen3's non-thinking-mode convention;
# confirm against the deployed model's expected prompt format.
QWEN3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "<|im_start|>user\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'assistant' %}"
    "<|im_start|>assistant\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'system' %}"
    "<|im_start|>system\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n<think>\n\n</think>\n"
    "{% endif %}"
)
|
|
|
|
def build_llm_metadata_proto(max_tokens: int, model_type: str = "gemma3") -> bytes:
    """Build and serialize the LlmMetadata proto embedded in the bundle.

    Args:
        max_tokens: Value written to ``max_num_tokens``.
        model_type: ``"qwen3"`` selects the Qwen3/DictaLM configuration;
            any other value falls back to Gemma3/TranslateGemma.

    Returns:
        The serialized ``LlmMetadata`` message as bytes.
    """
    meta = llm_metadata_pb2.LlmMetadata()
    meta.max_num_tokens = max_tokens

    if model_type == "qwen3":
        meta.llm_model_type.qwen3.CopyFrom(llm_model_type_pb2.Qwen3())
        meta.jinja_prompt_template = QWEN3_CHAT_TEMPLATE
        # NOTE(review): 151643/151645 are presumably <|endoftext|> and
        # <|im_end|> in the Qwen vocab -- verify against the tokenizer.
        start_id, stop_ids = 151643, (151645, 151643)
    else:
        meta.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
        meta.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE
        # BOS (2) as start token; stop on EOS (1) and end_of_turn (106).
        start_id, stop_ids = 2, (1, 106)

    meta.start_token.token_ids.ids.append(start_id)
    for token_id in stop_ids:
        meta.stop_tokens.add().token_ids.ids.append(token_id)

    return meta.SerializeToString()
|
|
|
|
def main():
    """CLI entry point: validate inputs, build metadata, write the bundle.

    Exits with status 1 (message on stderr) when the TFLite model or the
    tokenizer file is missing.
    """
    ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
    ap.add_argument("--tflite", required=True)
    ap.add_argument("--tokenizer", required=True, help="SentencePiece .model or HF tokenizer.json")
    ap.add_argument("--tokenizer-type", default="sp", choices=["sp", "hf"],
                    help="sp=SentencePiece (default), hf=HuggingFace tokenizer.json")
    ap.add_argument("--model-type", default="gemma3", choices=["gemma3", "qwen3"],
                    help="LlmMetadata model type (gemma3=TranslateGemma, qwen3=DictaLM/Qwen3)")
    ap.add_argument("--output", required=True)
    ap.add_argument("--max-tokens", type=int, default=2048)
    ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
    args = ap.parse_args()

    tflite_path = Path(args.tflite)
    tokenizer_path = Path(args.tokenizer)
    output_path = Path(args.output)

    if not tflite_path.exists():
        print(f"[x] TFLite not found: {tflite_path}", file=sys.stderr)
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"[x] Tokenizer not found: {tokenizer_path}", file=sys.stderr)
        sys.exit(1)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Serialize LlmMetadata to a temp file because the builder API takes a
    # path. delete=False so the data survives closing the handle (an open
    # NamedTemporaryFile cannot be reopened by path on Windows).
    meta_bytes = build_llm_metadata_proto(args.max_tokens, model_type=args.model_type)
    with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
        meta_file = Path(f.name)
        f.write(meta_bytes)

    print(f"[+] Building .litertlm: {output_path.name}")
    print(f" TFLite: {tflite_path} ({tflite_path.stat().st_size / 1e9:.2f} GB)")
    print(f" Tokenizer: {tokenizer_path}")
    print(f" Max tokens: {args.max_tokens}")

    Metadata = litertlm_builder.Metadata
    DType = litertlm_builder.DType

    # FIX: clean up the delete=False temp file even when bundling raises --
    # previously it was unlinked only after a successful build, leaking a
    # .pb file into the temp dir on every failure.
    try:
        builder = litertlm_builder.LitertLmFileBuilder()
        model_label = "DictaLM-3.0-1.7B" if args.model_type == "qwen3" else "TranslateGemma-4B-IT"
        builder.add_system_metadata(Metadata(key="model_name", value=f"{model_label}-{args.quant}", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="authors", value="google", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="quantization", value=args.quant, dtype=DType.STRING))

        builder.add_tflite_model(
            str(tflite_path),
            model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
        )
        if args.tokenizer_type == "hf":
            builder.add_hf_tokenizer(str(tokenizer_path))
        else:
            builder.add_sentencepiece_tokenizer(str(tokenizer_path))
        builder.add_llm_metadata(str(meta_file))

        with open(output_path, "wb") as f:
            builder.build(f)
    finally:
        meta_file.unlink(missing_ok=True)

    size = output_path.stat().st_size
    print(f"[+] Written: {output_path} ({size / 1e9:.2f} GB)")


if __name__ == "__main__":
    main()
|
|