#!/usr/bin/env python3
"""
Bundle a Strategy-1 KV-cache TFLite + SentencePiece tokenizer into a
.litertlm file compatible with Google AI Edge / LiteRT-LM runtime.
Embeds:
- LlmMetadata proto: Gemma3 model type, 2K max tokens, TranslateGemma
Jinja chat template, BOS/EOS/end_of_turn stop tokens
- TFLite model (model_type=prefill_decode)
- SentencePiece tokenizer
Usage:
python bundle_litertlm.py \
--tflite /path/to/model.tflite \
--tokenizer /path/to/tokenizer.model \
--output /path/to/output.litertlm \
[--max-tokens 2048]
"""
import argparse
import sys
import tempfile
from pathlib import Path
# Make litert_lm package importable from /tmp/litert-lm-pkg
sys.path.insert(0, "/tmp/litert-lm-pkg")
from litert_lm_builder import litertlm_builder
from litert_lm.runtime.proto import (
llm_metadata_pb2,
llm_model_type_pb2,
token_pb2,
)
# Generic Jinja template for arbitrary language-pair translation.
# Accepts a structured request of the form
#   <src>LANG</src><dst>LANG</dst><text>TEXT</text>
# and otherwise forwards the user content verbatim as a plain turn.
# Restricted to the Jinja2 subset the LiteRT-LM runtime supports
# (no .get(), only basic string filters such as split/trim).
GENERIC_TRANSLATE_TEMPLATE = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{% set content = message['content'] | trim %}"
    "{% if '<src>' in content and '<dst>' in content and '<text>' in content %}"
    "{% set src_part = content | split('<src>') | last | split('</src>') | first | trim %}"
    "{% set dst_part = content | split('<dst>') | last | split('</dst>') | first | trim %}"
    "{% set text_part = content | split('<text>') | last | split('</text>') | first | trim %}"
    "<start_of_turn>user\n"
    "Translate {{ src_part }} to {{ dst_part }}.\n"
    "Produce only the translation, without explanations:\n\n\n"
    "{{ text_part }}\n"
    "<end_of_turn>\n"
    "{% else %}"
    "<start_of_turn>user\n"
    "{{ content }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% elif message['role'] == 'assistant' %}"
    "<start_of_turn>model\n"
    "{{ message['content'] | trim }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<start_of_turn>model\n"
    "{% endif %}"
)

# TranslateGemma currently shares the generic translation template.
TRANSLATE_GEMMA_JINJA_TEMPLATE = GENERIC_TRANSLATE_TEMPLATE
# Qwen3 chat template in ChatML format. Thinking is disabled by
# prefixing the generation prompt with an empty <think>\n\n</think> block.
QWEN3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "<|im_start|>user\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'assistant' %}"
    "<|im_start|>assistant\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'system' %}"
    "<|im_start|>system\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n<think>\n\n</think>\n"
    "{% endif %}"
)
def build_llm_metadata_proto(max_tokens: int, model_type: str = "gemma3") -> bytes:
    """Serialize an LlmMetadata proto for the requested model family.

    Args:
        max_tokens: Upper bound on tokens the runtime may process.
        model_type: "qwen3" selects the DictaLM/Qwen3 configuration;
            any other value (default "gemma3") selects TranslateGemma/Gemma3.

    Returns:
        The LlmMetadata message serialized to bytes.
    """
    metadata = llm_metadata_pb2.LlmMetadata()
    metadata.max_num_tokens = max_tokens

    if model_type == "qwen3":
        metadata.llm_model_type.qwen3.CopyFrom(llm_model_type_pb2.Qwen3())
        # BOS for Qwen3 is <|endoftext|> (id 151643).
        metadata.start_token.token_ids.ids.append(151643)
        # Generation stops on <|im_end|> (151645) or <|endoftext|> (151643).
        for stop_id in (151645, 151643):
            metadata.stop_tokens.add().token_ids.ids.append(stop_id)
        metadata.jinja_prompt_template = QWEN3_CHAT_TEMPLATE
    else:
        # Gemma3 text-only variant — no vision config is needed for
        # TranslateGemma text mode.
        metadata.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
        # BOS = 2; generation stops on EOS (1) or end_of_turn (106).
        metadata.start_token.token_ids.ids.append(2)
        for stop_id in (1, 106):
            metadata.stop_tokens.add().token_ids.ids.append(stop_id)
        metadata.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE

    return metadata.SerializeToString()
def main():
    """CLI entry point: validate inputs and write the .litertlm bundle.

    Exits with status 1 (message on stderr) when an input file is missing.
    Fix over previous version: the temporary LlmMetadata .pb file is now
    removed in a ``finally`` block, so it no longer leaks when any builder
    call or ``build()`` raises.
    """
    ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
    ap.add_argument("--tflite", required=True)
    ap.add_argument("--tokenizer", required=True, help="SentencePiece .model or HF tokenizer.json")
    ap.add_argument("--tokenizer-type", default="sp", choices=["sp", "hf"],
                    help="sp=SentencePiece (default), hf=HuggingFace tokenizer.json")
    ap.add_argument("--model-type", default="gemma3", choices=["gemma3", "qwen3"],
                    help="LlmMetadata model type (gemma3=TranslateGemma, qwen3=DictaLM/Qwen3)")
    ap.add_argument("--output", required=True)
    ap.add_argument("--max-tokens", type=int, default=2048)
    ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
    args = ap.parse_args()

    tflite_path = Path(args.tflite)
    tokenizer_path = Path(args.tokenizer)
    output_path = Path(args.output)

    # Fail fast on missing inputs before doing any work.
    if not tflite_path.exists():
        print(f"[x] TFLite not found: {tflite_path}", file=sys.stderr)
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"[x] Tokenizer not found: {tokenizer_path}", file=sys.stderr)
        sys.exit(1)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The builder API takes file paths, so the serialized LlmMetadata proto
    # is staged in a temp file. delete=False keeps it on disk after the
    # `with` closes (required on Windows, where an open NamedTemporaryFile
    # cannot be reopened by path); cleanup happens in the finally below.
    meta_bytes = build_llm_metadata_proto(args.max_tokens, model_type=args.model_type)
    with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
        meta_file = Path(f.name)
        f.write(meta_bytes)

    try:
        print(f"[+] Building .litertlm: {output_path.name}")
        print(f" TFLite: {tflite_path} ({tflite_path.stat().st_size / 1e9:.2f} GB)")
        print(f" Tokenizer: {tokenizer_path}")
        print(f" Max tokens: {args.max_tokens}")

        Metadata = litertlm_builder.Metadata
        DType = litertlm_builder.DType
        builder = litertlm_builder.LitertLmFileBuilder()
        model_label = "DictaLM-3.0-1.7B" if args.model_type == "qwen3" else "TranslateGemma-4B-IT"
        builder.add_system_metadata(Metadata(key="model_name", value=f"{model_label}-{args.quant}", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="authors", value="google", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="quantization", value=args.quant, dtype=DType.STRING))
        builder.add_tflite_model(
            str(tflite_path),
            model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
        )
        if args.tokenizer_type == "hf":
            builder.add_hf_tokenizer(str(tokenizer_path))
        else:
            builder.add_sentencepiece_tokenizer(str(tokenizer_path))
        builder.add_llm_metadata(str(meta_file))
        with open(output_path, "wb") as f:
            builder.build(f)
    finally:
        # Remove the staged metadata file even if bundling failed.
        meta_file.unlink(missing_ok=True)

    size = output_path.stat().st_size
    print(f"[+] Written: {output_path} ({size / 1e9:.2f} GB)")


if __name__ == "__main__":
    main()