barakplasma committed on
Commit
e468f0d
·
verified ·
1 Parent(s): cd2eec1

Upload scripts/bundle_litertlm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/bundle_litertlm.py +70 -30
scripts/bundle_litertlm.py CHANGED
@@ -33,23 +33,33 @@ from litert_lm.runtime.proto import (
33
  )
34
 
35
 
36
- # Simple Jinja template compatible with LiteRT-LM runtime (no .get(), no complex tests).
37
- # Handles plain text input from Google AI Edge Gallery.
38
- # Uses the exact prompt format TranslateGemma was trained with (en→es default).
39
- # Users who need other language pairs should prefix their message with the pair,
40
- # e.g. "Translate English to French:\n\nHello"
41
- TRANSLATE_GEMMA_JINJA_TEMPLATE = \
42
  "{{ bos_token }}" \
43
  "{% for message in messages %}" \
44
  "{% if message['role'] == 'user' %}" \
 
 
 
 
 
 
 
 
 
 
 
45
  "<start_of_turn>user\n" \
46
- "You are a professional translator. " \
47
- "Produce only the translation of the following text, without any additional explanations or commentary:\n\n\n" \
48
- "{{ message['content'] | trim }}" \
49
  "<end_of_turn>\n" \
 
50
  "{% elif message['role'] == 'assistant' %}" \
51
  "<start_of_turn>model\n" \
52
- "{{ message['content'] | trim }}" \
53
  "<end_of_turn>\n" \
54
  "{% endif %}" \
55
  "{% endfor %}" \
@@ -57,26 +67,48 @@ TRANSLATE_GEMMA_JINJA_TEMPLATE = \
57
  "<start_of_turn>model\n" \
58
  "{% endif %}"
59
 
 
60
 
61
- def build_llm_metadata_proto(max_tokens: int) -> bytes:
62
- meta = llm_metadata_pb2.LlmMetadata()
63
- meta.max_num_tokens = max_tokens
64
-
65
- # Model type: Gemma3 (text-only variant — no vision config needed for TranslateGemma text mode)
66
- meta.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
67
-
68
- # Start token: BOS = token id 2
69
- meta.start_token.token_ids.ids.append(2)
 
 
 
 
 
70
 
71
- # Stop tokens: EOS (id=1) and end_of_turn (id=106)
72
- eos = meta.stop_tokens.add()
73
- eos.token_ids.ids.append(1)
74
 
75
- eot = meta.stop_tokens.add()
76
- eot.token_ids.ids.append(106)
 
77
 
78
- # Embed the Jinja template
79
- meta.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  return meta.SerializeToString()
82
 
@@ -84,7 +116,11 @@ def build_llm_metadata_proto(max_tokens: int) -> bytes:
84
  def main():
85
  ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
86
  ap.add_argument("--tflite", required=True)
87
- ap.add_argument("--tokenizer", required=True, help="SentencePiece .model file")
 
 
 
 
88
  ap.add_argument("--output", required=True)
89
  ap.add_argument("--max-tokens", type=int, default=2048)
90
  ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
@@ -104,7 +140,7 @@ def main():
104
  output_path.parent.mkdir(parents=True, exist_ok=True)
105
 
106
  # Write LlmMetadata to temp file
107
- meta_bytes = build_llm_metadata_proto(args.max_tokens)
108
  with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
109
  meta_file = Path(f.name)
110
  f.write(meta_bytes)
@@ -118,7 +154,8 @@ def main():
118
  DType = litertlm_builder.DType
119
 
120
  builder = litertlm_builder.LitertLmFileBuilder()
121
- builder.add_system_metadata(Metadata(key="model_name", value=f"TranslateGemma-4B-IT-{args.quant}", dtype=DType.STRING))
 
122
  builder.add_system_metadata(Metadata(key="authors", value="google", dtype=DType.STRING))
123
  builder.add_system_metadata(Metadata(key="quantization", value=args.quant, dtype=DType.STRING))
124
 
@@ -126,7 +163,10 @@ def main():
126
  str(tflite_path),
127
  model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
128
  )
129
- builder.add_sentencepiece_tokenizer(str(tokenizer_path))
 
 
 
130
  builder.add_llm_metadata(str(meta_file))
131
 
132
  with open(output_path, "wb") as f:
 
33
  )
34
 
35
 
36
+ # Generic Jinja template for arbitrary language pair translation.
37
+ # Supports structured XML-like input format: <src>LANG</src><dst>LANG</dst><text>TEXT</text>
38
+ # Falls back to plain text if XML tags not provided.
39
+ # Uses only Jinja2 features supported by LiteRT-LM runtime (no .get(), basic string ops).
40
+
41
+ GENERIC_TRANSLATE_TEMPLATE = \
42
  "{{ bos_token }}" \
43
  "{% for message in messages %}" \
44
  "{% if message['role'] == 'user' %}" \
45
+ "{% set content = message['content'] | trim %}" \
46
+ "{% if '<src>' in content and '<dst>' in content and '<text>' in content %}" \
47
+ "{% set src_part = content | split('<src>') | last | split('</src>') | first | trim %}" \
48
+ "{% set dst_part = content | split('<dst>') | last | split('</dst>') | first | trim %}" \
49
+ "{% set text_part = content | split('<text>') | last | split('</text>') | first | trim %}" \
50
+ "<start_of_turn>user\n" \
51
+ "Translate {{ src_part }} to {{ dst_part }}.\n" \
52
+ "Produce only the translation, without explanations:\n\n\n" \
53
+ "{{ text_part }}\n" \
54
+ "<end_of_turn>\n" \
55
+ "{% else %}" \
56
  "<start_of_turn>user\n" \
57
+ "{{ content }}\n" \
 
 
58
  "<end_of_turn>\n" \
59
+ "{% endif %}" \
60
  "{% elif message['role'] == 'assistant' %}" \
61
  "<start_of_turn>model\n" \
62
+ "{{ message['content'] | trim }}\n" \
63
  "<end_of_turn>\n" \
64
  "{% endif %}" \
65
  "{% endfor %}" \
 
67
  "<start_of_turn>model\n" \
68
  "{% endif %}"
69
 
70
+ TRANSLATE_GEMMA_JINJA_TEMPLATE = GENERIC_TRANSLATE_TEMPLATE
71
 
72
+ # Qwen3 chat template (ChatML format, no-think mode via <think>\n\n</think> prefix)
73
+ QWEN3_CHAT_TEMPLATE = \
74
+ "{% for message in messages %}" \
75
+ "{% if message['role'] == 'user' %}" \
76
+ "<|im_start|>user\n{{ message['content'] | trim }}<|im_end|>\n" \
77
+ "{% elif message['role'] == 'assistant' %}" \
78
+ "<|im_start|>assistant\n{{ message['content'] | trim }}<|im_end|>\n" \
79
+ "{% elif message['role'] == 'system' %}" \
80
+ "<|im_start|>system\n{{ message['content'] | trim }}<|im_end|>\n" \
81
+ "{% endif %}" \
82
+ "{% endfor %}" \
83
+ "{% if add_generation_prompt %}" \
84
+ "<|im_start|>assistant\n<think>\n\n</think>\n" \
85
+ "{% endif %}"
86
 
 
 
 
87
 
88
+ def build_llm_metadata_proto(max_tokens: int, model_type: str = "gemma3") -> bytes:
89
+ meta = llm_metadata_pb2.LlmMetadata()
90
+ meta.max_num_tokens = max_tokens
91
 
92
+ if model_type == "qwen3":
93
+ meta.llm_model_type.qwen3.CopyFrom(llm_model_type_pb2.Qwen3())
94
+ # Qwen3 BOS: <|endoftext|> = 151643
95
+ meta.start_token.token_ids.ids.append(151643)
96
+ # Stop tokens: <|im_end|> = 151645, <|endoftext|> = 151643
97
+ for tid in [151645, 151643]:
98
+ st = meta.stop_tokens.add()
99
+ st.token_ids.ids.append(tid)
100
+ meta.jinja_prompt_template = QWEN3_CHAT_TEMPLATE
101
+ else:
102
+ # Model type: Gemma3 (text-only variant — no vision config needed for TranslateGemma text mode)
103
+ meta.llm_model_type.gemma3.CopyFrom(llm_model_type_pb2.Gemma3())
104
+ # Start token: BOS = token id 2
105
+ meta.start_token.token_ids.ids.append(2)
106
+ # Stop tokens: EOS (id=1) and end_of_turn (id=106)
107
+ eos = meta.stop_tokens.add()
108
+ eos.token_ids.ids.append(1)
109
+ eot = meta.stop_tokens.add()
110
+ eot.token_ids.ids.append(106)
111
+ meta.jinja_prompt_template = TRANSLATE_GEMMA_JINJA_TEMPLATE
112
 
113
  return meta.SerializeToString()
114
 
 
116
  def main():
117
  ap = argparse.ArgumentParser(description="Bundle TFLite + tokenizer into .litertlm")
118
  ap.add_argument("--tflite", required=True)
119
+ ap.add_argument("--tokenizer", required=True, help="SentencePiece .model or HF tokenizer.json")
120
+ ap.add_argument("--tokenizer-type", default="sp", choices=["sp", "hf"],
121
+ help="sp=SentencePiece (default), hf=HuggingFace tokenizer.json")
122
+ ap.add_argument("--model-type", default="gemma3", choices=["gemma3", "qwen3"],
123
+ help="LlmMetadata model type (gemma3=TranslateGemma, qwen3=DictaLM/Qwen3)")
124
  ap.add_argument("--output", required=True)
125
  ap.add_argument("--max-tokens", type=int, default=2048)
126
  ap.add_argument("--quant", default="int8", help="Quantization label for metadata")
 
140
  output_path.parent.mkdir(parents=True, exist_ok=True)
141
 
142
  # Write LlmMetadata to temp file
143
+ meta_bytes = build_llm_metadata_proto(args.max_tokens, model_type=args.model_type)
144
  with tempfile.NamedTemporaryFile(suffix=".pb", delete=False) as f:
145
  meta_file = Path(f.name)
146
  f.write(meta_bytes)
 
154
  DType = litertlm_builder.DType
155
 
156
  builder = litertlm_builder.LitertLmFileBuilder()
157
+ model_label = "DictaLM-3.0-1.7B" if args.model_type == "qwen3" else "TranslateGemma-4B-IT"
158
+ builder.add_system_metadata(Metadata(key="model_name", value=f"{model_label}-{args.quant}", dtype=DType.STRING))
159
  builder.add_system_metadata(Metadata(key="authors", value="google", dtype=DType.STRING))
160
  builder.add_system_metadata(Metadata(key="quantization", value=args.quant, dtype=DType.STRING))
161
 
 
163
  str(tflite_path),
164
  model_type=litertlm_builder.TfLiteModelType.PREFILL_DECODE,
165
  )
166
+ if args.tokenizer_type == "hf":
167
+ builder.add_hf_tokenizer(str(tokenizer_path))
168
+ else:
169
+ builder.add_sentencepiece_tokenizer(str(tokenizer_path))
170
  builder.add_llm_metadata(str(meta_file))
171
 
172
  with open(output_path, "wb") as f: