# NOTE: the three lines below were Hugging Face web-page artifacts (uploader,
# commit hash) accidentally prepended to the script; kept as a comment so the
# file remains valid Python: uploaded by barakplasma via huggingface_hub,
# commit e468f0d (verified).
#!/usr/bin/env python3
"""
Bundle a Strategy-1 KV-cache TFLite + SentencePiece tokenizer into a
.litertlm file compatible with Google AI Edge / LiteRT-LM runtime.
Embeds:
- LlmMetadata proto: Gemma3 model type, 2K max tokens, TranslateGemma
Jinja chat template, BOS/EOS/end_of_turn stop tokens
- TFLite model (model_type=prefill_decode)
- SentencePiece tokenizer
Usage:
python bundle_litertlm.py \
--tflite /path/to/model.tflite \
--tokenizer /path/to/tokenizer.model \
--output /path/to/output.litertlm \
[--max-tokens 2048]
"""
import argparse
import sys
import tempfile
from pathlib import Path
# Make litert_lm package importable from /tmp/litert-lm-pkg
sys.path.insert(0, "/tmp/litert-lm-pkg")
from litert_lm_builder import litertlm_builder
from litert_lm.runtime.proto import (
llm_metadata_pb2,
llm_model_type_pb2,
token_pb2,
)
# Jinja chat template for arbitrary-language-pair translation (Gemma turn format).
# Structured input is recognized via XML-like tags:
#   <src>LANG</src><dst>LANG</dst><text>TEXT</text>
# and expanded into an explicit "Translate X to Y" instruction; plain text
# falls through to an ordinary user turn. Only Jinja2 features supported by
# the LiteRT-LM runtime are used (no .get(), basic string filters only).
GENERIC_TRANSLATE_TEMPLATE = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{% set content = message['content'] | trim %}"
    "{% if '<src>' in content and '<dst>' in content and '<text>' in content %}"
    "{% set src_part = content | split('<src>') | last | split('</src>') | first | trim %}"
    "{% set dst_part = content | split('<dst>') | last | split('</dst>') | first | trim %}"
    "{% set text_part = content | split('<text>') | last | split('</text>') | first | trim %}"
    "<start_of_turn>user\n"
    "Translate {{ src_part }} to {{ dst_part }}.\n"
    "Produce only the translation, without explanations:\n\n\n"
    "{{ text_part }}\n"
    "<end_of_turn>\n"
    "{% else %}"
    "<start_of_turn>user\n"
    "{{ content }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% elif message['role'] == 'assistant' %}"
    "<start_of_turn>model\n"
    "{{ message['content'] | trim }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<start_of_turn>model\n"
    "{% endif %}"
)

# TranslateGemma uses the generic translation template unchanged.
TRANSLATE_GEMMA_JINJA_TEMPLATE = GENERIC_TRANSLATE_TEMPLATE

# Qwen3 chat template in ChatML format. The generation prompt pre-seeds an
# empty "<think>\n\n</think>" block, which puts the model in no-think mode.
QWEN3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "<|im_start|>user\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'assistant' %}"
    "<|im_start|>assistant\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'system' %}"
    "<|im_start|>system\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n<think>\n\n</think>\n"
    "{% endif %}"
)
def build_llm_metadata_proto(max_tokens: int, model_type: str = "gemma3") -> bytes:
    """Serialize an LlmMetadata proto for the requested model family.

    Args:
        max_tokens: Value stored in ``max_num_tokens``.
        model_type: ``"qwen3"`` selects the Qwen3/ChatML configuration;
            any other value falls back to Gemma3 (TranslateGemma).

    Returns:
        The serialized ``LlmMetadata`` message as bytes.
    """
    metadata = llm_metadata_pb2.LlmMetadata()
    metadata.max_num_tokens = max_tokens

    if model_type == "qwen3":
        metadata.llm_model_type.qwen3.CopyFrom(llm_model_type_pb2.Qwen3())
        # <|endoftext|> (151643) serves as BOS and also as a stop token;
        # <|im_end|> (151645) terminates each chat turn.
        metadata.start_token.token_ids.ids.append(151643)
        for stop_id in (151645, 151643):
            metadata.stop_tokens.add().token_ids.ids.append(stop_id)
        metadata.jinja_prompt_template = QWEN3_CHAT_TEMPLATE
    else:
        # Gemma3 text-only variant: no vision config is attached, which is
        # sufficient for TranslateGemma's text-mode usage.
        metadata.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
        metadata.start_token.token_ids.ids.append(2)  # BOS
        for stop_id in (1, 106):  # EOS, <end_of_turn>
            metadata.stop_tokens.add().token_ids.ids.append(stop_id)
        metadata.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE

    return metadata.SerializeToString()
def main() -> None:
    """CLI entry point: parse arguments, validate inputs, and write the bundle.

    Exits with status 1 (after printing to stderr) when the TFLite model or
    tokenizer file does not exist.
    """
    ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
    ap.add_argument("--tflite", required=True)
    ap.add_argument("--tokenizer", required=True, help="SentencePiece .model or HF tokenizer.json")
    ap.add_argument("--tokenizer-type", default="sp", choices=["sp", "hf"],
                    help="sp=SentencePiece (default), hf=HuggingFace tokenizer.json")
    ap.add_argument("--model-type", default="gemma3", choices=["gemma3", "qwen3"],
                    help="LlmMetadata model type (gemma3=TranslateGemma, qwen3=DictaLM/Qwen3)")
    ap.add_argument("--output", required=True)
    ap.add_argument("--max-tokens", type=int, default=2048)
    ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
    args = ap.parse_args()

    tflite_path = Path(args.tflite)
    tokenizer_path = Path(args.tokenizer)
    output_path = Path(args.output)
    if not tflite_path.exists():
        print(f"[x] TFLite not found: {tflite_path}", file=sys.stderr)
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"[x] Tokenizer not found: {tokenizer_path}", file=sys.stderr)
        sys.exit(1)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The builder API consumes file paths, so the serialized LlmMetadata is
    # staged in a temp file (delete=False: the builder reads it after close).
    meta_bytes = build_llm_metadata_proto(args.max_tokens, model_type=args.model_type)
    with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
        meta_file = Path(f.name)
        f.write(meta_bytes)

    print(f"[+] Building .litertlm: {output_path.name}")
    print(f" TFLite: {tflite_path} ({tflite_path.stat().st_size / 1e9:.2f} GB)")
    print(f" Tokenizer: {tokenizer_path}")
    print(f" Max tokens: {args.max_tokens}")

    # try/finally ensures the temp metadata file is removed even when the
    # builder raises (previously it leaked on any failure path).
    try:
        Metadata = litertlm_builder.Metadata
        DType = litertlm_builder.DType
        builder = litertlm_builder.LitertLmFileBuilder()
        model_label = "DictaLM-3.0-1.7B" if args.model_type == "qwen3" else "TranslateGemma-4B-IT"
        builder.add_system_metadata(Metadata(key="model_name", value=f"{model_label}-{args.quant}", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="authors", value="google", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="quantization", value=args.quant, dtype=DType.STRING))
        builder.add_tflite_model(
            str(tflite_path),
            model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
        )
        if args.tokenizer_type == "hf":
            builder.add_hf_tokenizer(str(tokenizer_path))
        else:
            builder.add_sentencepiece_tokenizer(str(tokenizer_path))
        builder.add_llm_metadata(str(meta_file))
        with open(output_path, "wb") as out:
            builder.build(out)
    finally:
        meta_file.unlink(missing_ok=True)

    size = output_path.stat().st_size
    print(f"[+] Written: {output_path} ({size / 1e9:.2f} GB)")


if __name__ == "__main__":
    main()