File size: 6,953 Bytes
3d6b6ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e468f0d
 
 
 
 
 
3bb2de3
 
 
e468f0d
 
 
 
 
 
 
 
 
 
 
3bb2de3
e468f0d
3bb2de3
e468f0d
3bb2de3
 
e468f0d
3bb2de3
 
 
 
 
 
3d6b6ce
e468f0d
3d6b6ce
e468f0d
 
 
 
 
 
 
 
 
 
 
 
 
 
3d6b6ce
 
e468f0d
 
 
3d6b6ce
e468f0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d6b6ce
 
 
 
 
 
 
e468f0d
 
 
 
 
3d6b6ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e468f0d
3d6b6ce
 
 
 
 
 
 
 
 
 
 
 
 
e468f0d
 
3d6b6ce
 
 
 
 
 
 
e468f0d
 
 
 
3d6b6ce
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/bin/env python3
"""
Bundle a Strategy-1 KV-cache TFLite + SentencePiece tokenizer into a
.litertlm file compatible with Google AI Edge / LiteRT-LM runtime.

Embeds:
  - LlmMetadata proto: Gemma3 model type, 2K max tokens, TranslateGemma
    Jinja chat template, BOS/EOS/end_of_turn stop tokens
  - TFLite model  (model_type=prefill_decode)
  - SentencePiece tokenizer

Usage:
  python bundle_litertlm.py \
    --tflite /path/to/model.tflite \
    --tokenizer /path/to/tokenizer.model \
    --output /path/to/output.litertlm \
    [--max-tokens 2048] [--model-type gemma3|qwen3] \
    [--tokenizer-type sp|hf] [--quant int8]
"""

import argparse
import sys
import tempfile
from pathlib import Path

# Make litert_lm package importable from /tmp/litert-lm-pkg
sys.path.insert(0, "/tmp/litert-lm-pkg")

from litert_lm_builder import litertlm_builder
from litert_lm.runtime.proto import (
    llm_metadata_pb2,
    llm_model_type_pb2,
    token_pb2,
)


# Generic Jinja chat template for arbitrary language-pair translation.
# Accepts a structured request of the form
#   <src>LANG</src><dst>LANG</dst><text>TEXT</text>
# and falls back to passing the message through verbatim when those tags
# are absent. Restricted to Jinja2 features the LiteRT-LM runtime supports
# (no .get(); only basic string filters such as split/first/last/trim).
GENERIC_TRANSLATE_TEMPLATE = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{% set content = message['content'] | trim %}"
    "{% if '<src>' in content and '<dst>' in content and '<text>' in content %}"
    "{% set src_part = content | split('<src>') | last | split('</src>') | first | trim %}"
    "{% set dst_part = content | split('<dst>') | last | split('</dst>') | first | trim %}"
    "{% set text_part = content | split('<text>') | last | split('</text>') | first | trim %}"
    "<start_of_turn>user\n"
    "Translate {{ src_part }} to {{ dst_part }}.\n"
    "Produce only the translation, without explanations:\n\n\n"
    "{{ text_part }}\n"
    "<end_of_turn>\n"
    "{% else %}"
    "<start_of_turn>user\n"
    "{{ content }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% elif message['role'] == 'assistant' %}"
    "<start_of_turn>model\n"
    "{{ message['content'] | trim }}\n"
    "<end_of_turn>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<start_of_turn>model\n"
    "{% endif %}"
)

# TranslateGemma uses the generic template unchanged.
TRANSLATE_GEMMA_JINJA_TEMPLATE = GENERIC_TRANSLATE_TEMPLATE

# Qwen3 chat template in ChatML format. The generation prompt pre-seeds an
# empty <think>\n\n</think> section, which puts the model in no-think mode.
QWEN3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "<|im_start|>user\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'assistant' %}"
    "<|im_start|>assistant\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% elif message['role'] == 'system' %}"
    "<|im_start|>system\n{{ message['content'] | trim }}<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n<think>\n\n</think>\n"
    "{% endif %}"
)


def build_llm_metadata_proto(max_tokens: int, model_type: str = "gemma3") -> bytes:
    """Serialize an LlmMetadata proto for the requested model family.

    Args:
        max_tokens: Value written to max_num_tokens (context-window limit).
        model_type: "qwen3" selects the Qwen3/ChatML configuration; any
            other value selects the Gemma3/TranslateGemma configuration.

    Returns:
        The LlmMetadata message serialized to bytes.
    """
    metadata = llm_metadata_pb2.LlmMetadata()
    metadata.max_num_tokens = max_tokens

    if model_type == "qwen3":
        metadata.llm_model_type.qwen3.CopyFrom(llm_model_type_pb2.Qwen3())
        # Qwen3 BOS: <|endoftext|> = 151643
        metadata.start_token.token_ids.ids.append(151643)
        # Stop tokens: <|im_end|> = 151645, <|endoftext|> = 151643
        for token_id in (151645, 151643):
            metadata.stop_tokens.add().token_ids.ids.append(token_id)
        metadata.jinja_prompt_template = QWEN3_CHAT_TEMPLATE
    else:
        # Gemma3, text-only variant — no vision config needed for
        # TranslateGemma text mode.
        metadata.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
        # Start token: BOS = id 2
        metadata.start_token.token_ids.ids.append(2)
        # Stop tokens: EOS (id=1) and end_of_turn (id=106)
        for token_id in (1, 106):
            metadata.stop_tokens.add().token_ids.ids.append(token_id)
        metadata.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE

    return metadata.SerializeToString()


def main():
    """CLI entry point: validate inputs, then bundle model + tokenizer + metadata.

    Parses arguments, checks that the TFLite model and tokenizer files exist
    (exits with status 1 otherwise), serializes the LlmMetadata proto to a
    temp file, and streams everything into the .litertlm output.
    """
    ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
    ap.add_argument("--tflite", required=True)
    ap.add_argument("--tokenizer", required=True, help="SentencePiece .model or HF tokenizer.json")
    ap.add_argument("--tokenizer-type", default="sp", choices=["sp", "hf"],
                    help="sp=SentencePiece (default), hf=HuggingFace tokenizer.json")
    ap.add_argument("--model-type", default="gemma3", choices=["gemma3", "qwen3"],
                    help="LlmMetadata model type (gemma3=TranslateGemma, qwen3=DictaLM/Qwen3)")
    ap.add_argument("--output", required=True)
    ap.add_argument("--max-tokens", type=int, default=2048)
    ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
    args = ap.parse_args()

    tflite_path = Path(args.tflite)
    tokenizer_path = Path(args.tokenizer)
    output_path = Path(args.output)

    # Fail fast on missing inputs before doing any work.
    if not tflite_path.exists():
        print(f"[x] TFLite not found: {tflite_path}", file=sys.stderr)
        sys.exit(1)
    if not tokenizer_path.exists():
        print(f"[x] Tokenizer not found: {tokenizer_path}", file=sys.stderr)
        sys.exit(1)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Write LlmMetadata to a temp file — the builder API takes file paths,
    # not in-memory bytes.
    meta_bytes = build_llm_metadata_proto(args.max_tokens, model_type=args.model_type)
    with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
        meta_file = Path(f.name)
        f.write(meta_bytes)

    print(f"[+] Building .litertlm: {output_path.name}")
    print(f"    TFLite: {tflite_path} ({tflite_path.stat().st_size / 1e9:.2f} GB)")
    print(f"    Tokenizer: {tokenizer_path}")
    print(f"    Max tokens: {args.max_tokens}")

    try:
        Metadata = litertlm_builder.Metadata
        DType = litertlm_builder.DType

        builder = litertlm_builder.LitertLmFileBuilder()
        model_label = "DictaLM-3.0-1.7B" if args.model_type == "qwen3" else "TranslateGemma-4B-IT"
        builder.add_system_metadata(Metadata(key="model_name", value=f"{model_label}-{args.quant}", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="authors", value="google", dtype=DType.STRING))
        builder.add_system_metadata(Metadata(key="quantization", value=args.quant, dtype=DType.STRING))

        builder.add_tflite_model(
            str(tflite_path),
            model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
        )
        if args.tokenizer_type == "hf":
            builder.add_hf_tokenizer(str(tokenizer_path))
        else:
            builder.add_sentencepiece_tokenizer(str(tokenizer_path))
        builder.add_llm_metadata(str(meta_file))

        with open(output_path, "wb") as f:
            builder.build(f)
    finally:
        # Always remove the temp metadata file, even if the build raised —
        # previously it leaked on any builder error (delete=False above).
        meta_file.unlink(missing_ok=True)

    size = output_path.stat().st_size
    print(f"[+] Written: {output_path} ({size / 1e9:.2f} GB)")


if __name__ == "__main__":
    main()