#!/usr/bin/env python3
"""
Bundle a Strategy-1 KV-cache TFLite + SentencePiece tokenizer into a
.litertlm file compatible with Google AI Edge / LiteRT-LM runtime.

Embeds:
 - LlmMetadata proto: Gemma3 model type, 2K max tokens, TranslateGemma
   Jinja chat template, BOS/EOS/end_of_turn stop tokens
 - TFLite model (model_type=prefill_decode)
 - SentencePiece tokenizer

Usage:
    python bundle_litertlm.py \
        --tflite /path/to/model.tflite \
        --tokenizer /path/to/tokenizer.model \
        --output /path/to/output.litertlm \
        [--max-tokens 2048]
"""

import argparse
import sys
import tempfile
from pathlib import Path

# Make litert_lm package importable from /tmp/litert-lm-pkg
sys.path.insert(0, "/tmp/litert-lm-pkg")

from litert_lm_builder import litertlm_builder
from litert_lm.runtime.proto import (
    llm_metadata_pb2,
    llm_model_type_pb2,
    token_pb2,
)

# Generic Jinja template for arbitrary language-pair translation.
# Supports a structured XML-like input format (source-language /
# target-language / text tags) and falls back to plain text when the
# tags are not provided.  Uses only Jinja2 features supported by the
# LiteRT-LM runtime (no .get(), basic string ops).
# NOTE(review): the angle-bracket tokens in the templates below were stripped
# from this file by an earlier tag-sanitization pass (the broken remnants
# tested `'' in content` and emitted bare "user\n" turns).  Restored here:
#  - Gemma turn delimiters <start_of_turn>/<end_of_turn> (stop token id 106
#    registered below is <end_of_turn>, so the model expects them);
#  - Qwen3's standard no-think prefix <think>\n\n</think>\n\n (the remnant
#    was exactly that string with the tags removed);
#  - the structured-input tags <src>/<dst>/<text>.  The tag NAMES are a
#    reconstruction — confirm they match whatever the calling app emits.
GENERIC_TRANSLATE_TEMPLATE = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{% set content = message['content'] | trim %}"
    "{% if '<src>' in content and '<dst>' in content and '<text>' in content %}"
    "{% set src_part = content | split('<src>') | last | split('</src>') | first | trim %}"
    "{% set dst_part = content | split('<dst>') | last | split('</dst>') | first | trim %}"
    "{% set text_part = content | split('<text>') | last | split('</text>') | first | trim %}"
    "<start_of_turn>user\n"
    "Translate {{ src_part }} to {{ dst_part }}.\n"
    "Produce only the translation, without explanations:\n\n\n"
    "{{ text_part }}\n"
    "<end_of_turn>\n"
    "{% else %}"
    "<start_of_turn>user\n"
    "{{ content }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% elif message['role'] == 'assistant' %}"
    "<start_of_turn>model\n"
    "{{ message['content'] | trim }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<start_of_turn>model\n"
    "{% endif %}"
)

TRANSLATE_GEMMA_JINJA_TEMPLATE = GENERIC_TRANSLATE_TEMPLATE

# Qwen3 chat template (ChatML format; no-think mode is forced by pre-filling
# an empty <think>\n\n</think> block in the generation prompt).
QWEN3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "<|im_start|>user\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'assistant' %}"
    "<|im_start|>assistant\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'system' %}"
    "<|im_start|>system\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
    "{% endif %}"
)


def build_llm_metadata_proto(max_tokens: int, model_type: str = "gemma3") -> bytes:
    """Build and serialize the LlmMetadata proto for the chosen model family.

    Args:
        max_tokens: Value for ``LlmMetadata.max_num_tokens`` (context window).
        model_type: ``"gemma3"`` (default; TranslateGemma) or ``"qwen3"``
            (DictaLM/Qwen3).  Any unrecognized value falls through to the
            gemma3 branch, matching the original behavior.

    Returns:
        The serialized LlmMetadata proto bytes.
    """
    meta = llm_metadata_pb2.LlmMetadata()
    meta.max_num_tokens = max_tokens
    if model_type == "qwen3":
        meta.llm_model_type.qwen3.CopyFrom(llm_model_type_pb2.Qwen3())
        # Qwen3 BOS: <|endoftext|> = 151643
        meta.start_token.token_ids.ids.append(151643)
        # Stop tokens: <|im_end|> = 151645, <|endoftext|> = 151643
        for tid in (151645, 151643):
            stop = meta.stop_tokens.add()
            stop.token_ids.ids.append(tid)
        meta.jinja_prompt_template = QWEN3_CHAT_TEMPLATE
    else:
        # Gemma3, text-only variant — no vision config needed for
        # TranslateGemma text mode.
        meta.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
        # Start token: BOS = 2.  Stop tokens: EOS = 1, <end_of_turn> = 106.
        meta.start_token.token_ids.ids.append(2)
        for tid in (1, 106):
            stop = meta.stop_tokens.add()
            stop.token_ids.ids.append(tid)
        meta.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE
    return meta.SerializeToString()


def main():
    """CLI entry point: validate paths, then assemble the .litertlm bundle."""
    ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
    ap.add_argument("--tflite", required=True)
    ap.add_argument("--tokenizer", required=True,
                    help="SentencePiece .model or HF tokenizer.json")
    ap.add_argument("--tokenizer-type", default="sp", choices=["sp", "hf"],
                    help="sp=SentencePiece (default), hf=HuggingFace tokenizer.json")
    ap.add_argument("--model-type", default="gemma3", choices=["gemma3", "qwen3"],
                    help="LlmMetadata model type (gemma3=TranslateGemma, qwen3=DictaLM/Qwen3)")
    ap.add_argument("--output", required=True)
    ap.add_argument("--max-tokens", type=int, default=2048)
    ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
    args = ap.parse_args()

    tflite_path = Path(args.tflite)
    tokenizer_path = Path(args.tokenizer)
    output_path = Path(args.output)

    if not tflite_path.exists():
        print(f"[x] TFLite not found: {tflite_path}", file=sys.stderr)
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"[x] Tokenizer not found: {tokenizer_path}", file=sys.stderr)
        sys.exit(1)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The builder API takes file paths, so write the LlmMetadata proto to a
    # temp file first; delete=False because the builder re-opens it by name.
    meta_bytes = build_llm_metadata_proto(args.max_tokens, model_type=args.model_type)
    with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
        meta_file = Path(f.name)
        f.write(meta_bytes)

    # try/finally so the temp metadata file is removed even if bundling fails
    # (the original leaked it on any exception between creation and unlink).
    try:
        print(f"[+] Building .litertlm: {output_path.name}")
        print(f"    TFLite:     {tflite_path} ({tflite_path.stat().st_size / 1e9:.2f} GB)")
        print(f"    Tokenizer:  {tokenizer_path}")
        print(f"    Max tokens: {args.max_tokens}")

        Metadata = litertlm_builder.Metadata
        DType = litertlm_builder.DType
        builder = litertlm_builder.LitertLmFileBuilder()
        model_label = "DictaLM-3.0-1.7B" if args.model_type == "qwen3" else "TranslateGemma-4B-IT"
        builder.add_system_metadata(
            Metadata(key="model_name", value=f"{model_label}-{args.quant}", dtype=DType.STRING))
        builder.add_system_metadata(
            Metadata(key="authors", value="google", dtype=DType.STRING))
        builder.add_system_metadata(
            Metadata(key="quantization", value=args.quant, dtype=DType.STRING))
        builder.add_tflite_model(
            str(tflite_path),
            model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
        )
        if args.tokenizer_type == "hf":
            builder.add_hf_tokenizer(str(tokenizer_path))
        else:
            builder.add_sentencepiece_tokenizer(str(tokenizer_path))
        builder.add_llm_metadata(str(meta_file))
        with open(output_path, "wb") as out:
            builder.build(out)
    finally:
        meta_file.unlink(missing_ok=True)

    size = output_path.stat().st_size
    print(f"[+] Written: {output_path} ({size / 1e9:.2f} GB)")


if __name__ == "__main__":
    main()