#!/usr/bin/env python3
import argparse
import importlib
import inspect
import json
import os
import subprocess
import sys
import traceback
from pathlib import Path
# Default environment flags; setdefault keeps any value the caller already
# exported, and child processes started via subprocess inherit os.environ.
for _env_key, _env_val in (
    ("TRANSFORMERS_NO_TORCHVISION", "1"),
    ("PYTHONUNBUFFERED", "1"),
):
    os.environ.setdefault(_env_key, _env_val)
del _env_key, _env_val
def log(msg):
    """Print an informational status line (``[+]`` prefix) to stdout, flushed."""
    line = "[+] {}".format(msg)
    print(line, flush=True)
def warn(msg):
    """Print a warning line (``[!]`` prefix) to stdout, flushed."""
    line = "[!] {}".format(msg)
    print(line, flush=True)
def die(msg):
    """Print a fatal error (``[x]`` prefix) to stderr and exit with status 1."""
    print("[x] {}".format(msg), file=sys.stderr, flush=True)
    raise SystemExit(1)
# Canonical --quantize values (aliases are resolved by normalize_quantize).
SUPPORTED_QUANT = set(("none", "dynamic_int8", "float16", "int8", "int4"))
def normalize_quantize(q: str) -> str:
    """Map a user-supplied --quantize value to its canonical name.

    An empty or None value selects ``dynamic_int8``. Matching ignores case
    and surrounding whitespace, and short aliases (fp32, fp16, q8, q4, ...)
    are expanded. Calls die() when the result is not in SUPPORTED_QUANT.
    """
    alias_table = {
        "fp32": "none", "no": "none", "off": "none",
        "fp16": "float16", "f16": "float16",
        "i8": "int8", "q8": "int8",
        "i4": "int4", "q4": "int4",
    }
    raw = q if q else "dynamic_int8"
    key = raw.strip().lower()
    canonical = alias_table.get(key, key)
    if canonical not in SUPPORTED_QUANT:
        die(f"Unsupported --quantize '{canonical}'. Supported: {sorted(SUPPORTED_QUANT)}")
    return canonical
def load_config(model_dir: Path):
    """Read ``config.json`` from *model_dir* (str or Path).

    Returns ``(cfg, text_cfg)``, where ``text_cfg`` is the nested
    ``text_config`` mapping, or an empty dict when it is absent or null.
    Calls die() if the file is missing or is not valid JSON.
    """
    cfg_path = Path(model_dir) / "config.json"
    if not cfg_path.exists():
        die(f"Missing config: {cfg_path}")
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        die(f"Invalid JSON in {cfg_path}: {e}")
    # `or {}` also covers an explicit `"text_config": null`, which .get's
    # default would not.
    text_cfg = cfg.get("text_config") or {}
    return cfg, text_cfg
def inspect_arch(model_dir: Path):
    """Summarize the checkpoint architecture from its config.json and log it.

    Returns a dict with ``model_type`` (default "unknown"), ``architecture``
    (first entry of ``architectures``, default "unknown") and ``vocab_size``.
    The vocab size comes from the top-level config, then ``text_config``,
    then 262144.
    """
    cfg, text_cfg = load_config(model_dir)
    arch_list = cfg.get("architectures") or ["unknown"]
    vocab_fallback = text_cfg.get("vocab_size", 262144)
    info = dict(
        model_type=cfg.get("model_type", "unknown"),
        architecture=arch_list[0],
        vocab_size=cfg.get("vocab_size", vocab_fallback),
    )
    log(f"ARCH: {info}")
    return info
def ensure_model_downloaded(model_id: str, model_dir: Path, hf_token: str):
    """Download *model_id* from the Hugging Face Hub into *model_dir*.

    Skips the download when ``model_dir/config.json`` already exists. An empty
    or None *hf_token* is passed to the Hub as ``None``. Calls die() if the
    download raises or finishes without producing ``config.json``.
    """
    model_dir = Path(model_dir)
    cfg_path = model_dir / "config.json"
    if cfg_path.exists():
        log(f"Using existing model dir: {model_dir}")
        return
    log(f"Downloading {model_id} -> {model_dir}")
    try:
        from huggingface_hub import snapshot_download
        kwargs = {
            "repo_id": model_id,
            "local_dir": str(model_dir),
            "token": hf_token or None,
        }
        # `local_dir_use_symlinks` is deprecated in newer huggingface_hub
        # releases and may not be accepted by every version, so only pass it
        # when the installed signature still has it.
        if "local_dir_use_symlinks" in inspect.signature(snapshot_download).parameters:
            kwargs["local_dir_use_symlinks"] = False
        snapshot_download(**kwargs)
    except Exception as e:
        die(f"Model download failed: {e}")
    if not cfg_path.exists():
        die(f"Download finished but {cfg_path} is missing; check --model-id")
def try_builders(mod, builder_names, model_dir: Path):
    """Return the first model that a builder in *mod* constructs successfully.

    Each name in *builder_names* is looked up on *mod*; missing names are
    skipped. Builders are called with ``str(model_dir)``. A non-empty
    tuple/list result is reduced to its first element. Results that are None
    or lack an ``eval`` method are rejected. The accepted model is switched
    to eval mode. Builder exceptions are logged and the next name is tried.
    Returns None when nothing succeeds.
    """
    candidates = ((name, getattr(mod, name, None)) for name in builder_names)
    for name, builder in candidates:
        if builder is None:
            continue
        log(f"Trying {mod.__name__}.{name} ...")
        try:
            built = builder(str(model_dir))
            if built is None:
                warn(" returned None")
                continue
            if isinstance(built, (tuple, list)) and built:
                built = built[0]
            if not hasattr(built, "eval"):
                warn(f" unsupported return type: {type(built)}")
                continue
            built.eval()
            log(f" success with {name}")
            return built
        except Exception as exc:
            warn(f" failed: {exc}")
    return None
def build_translategemma_4b(checkpoint_path: str):
    """
    Custom builder for TranslateGemma 4B IT (Gemma3 multimodal decoder).

    Strips the 'language_model.' prefix from safetensors keys so the
    per-layer tensor names below (separate q/k/v projections) resolve.
    Architecture (from config.json / verified weight shapes):
    34 layers, embedding_dim=2560, 8 heads, head_dim=256, 4 KV heads,
    intermediate=10240, sliding_window=1024, global every 6th layer.

    Weights are read from every shard listed in
    ``model.safetensors.index.json``. When that index is absent, they are
    read from a single ``model.safetensors`` file. Returns the result of
    ``model_builder.build_decoder_only_model``.
    """
    import safetensors.torch as st_lib
    import json as json_lib
    from litert_torch.generative.utilities import model_builder, loader as loading_utils
    from litert_torch.generative.layers import kv_cache as kv_utils
    from litert_torch.generative.examples.gemma3 import decoder as gemma3_decoder
    import litert_torch.generative.layers.model_config as cfg_mod

    num_layers = 34
    global_every = 6
    lm_prefix = "language_model."

    norm = cfg_mod.NormalizationConfig(
        type=cfg_mod.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
    )
    ff = cfg_mod.FeedForwardConfig(
        type=cfg_mod.FeedForwardType.GATED,
        activation=cfg_mod.ActivationConfig(cfg_mod.ActivationType.GELU_TANH),
        intermediate_size=10240,
        pre_ff_norm_config=norm,
        post_ff_norm_config=norm,
    )

    def blk(idx):
        # Every 6th layer (idx 5, 11, ...) is GLOBAL attention with RoPE base
        # 1e6; all other layers are LOCAL_SLIDING with RoPE base 1e4.
        is_global = (idx + 1) % global_every == 0
        attn = cfg_mod.AttentionConfig(
            num_heads=8,
            head_dim=256,
            num_query_groups=4,
            rotary_base=1_000_000 if is_global else 10_000,
            rotary_percentage=1.0,
            qkv_transpose_before_split=True,
            qkv_fused_interleaved=False,
            query_norm_config=norm,
            key_norm_config=norm,
            logit_softcap=None,
            sliding_window_size=1024,
            attn_type=(
                cfg_mod.AttentionType.GLOBAL if is_global
                else cfg_mod.AttentionType.LOCAL_SLIDING
            ),
        )
        return cfg_mod.TransformerBlockConfig(
            attn_config=attn,
            ff_config=ff,
            pre_attention_norm_config=norm,
            post_attention_norm_config=norm,
        )

    model_config = cfg_mod.ModelConfig(
        vocab_size=262208,
        num_layers=num_layers,
        max_seq_len=8192,
        embedding_dim=2560,
        embedding_scale=2560 ** 0.5,
        block_configs=[blk(i) for i in range(num_layers)],
        final_norm_config=norm,
        lm_head_use_bias=False,
        final_logit_softcap=None,
    )
    tensor_names = loading_utils.ModelLoader.TensorNames(
        ff_up_proj="model.layers.{}.mlp.up_proj",
        ff_down_proj="model.layers.{}.mlp.down_proj",
        ff_gate_proj="model.layers.{}.mlp.gate_proj",
        attn_query_proj="model.layers.{}.self_attn.q_proj",
        attn_key_proj="model.layers.{}.self_attn.k_proj",
        attn_value_proj="model.layers.{}.self_attn.v_proj",
        attn_output_proj="model.layers.{}.self_attn.o_proj",
        attn_query_norm="model.layers.{}.self_attn.q_norm",
        attn_key_norm="model.layers.{}.self_attn.k_norm",
        pre_attn_norm="model.layers.{}.input_layernorm",
        post_attn_norm="model.layers.{}.post_attention_layernorm",
        pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
        post_ff_norm="model.layers.{}.post_feedforward_layernorm",
        embedding="model.embed_tokens",
        final_norm="model.norm",
        # NOTE(review): no separate lm_head tensor; presumably the output
        # projection is tied to the embedding — confirm in litert_torch.
        lm_head=None,
    )

    def custom_loader(path: str):
        # Merge all safetensors shards into one dict, dropping the multimodal
        # "language_model." prefix; keys without it are kept unchanged.
        root = Path(path)
        index_file = root / "model.safetensors.index.json"
        if index_file.exists():
            index = json_lib.loads(index_file.read_text(encoding="utf-8"))
            # sorted() gives a deterministic shard load order.
            shard_names = sorted(set(index["weight_map"].values()))
        else:
            shard_names = ["model.safetensors"]
        out = {}
        for fname in shard_names:
            for k, v in st_lib.load_file(str(root / fname)).items():
                out[k[len(lm_prefix):] if k.startswith(lm_prefix) else k] = v
        return out

    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=model_config,
        tensor_names=tensor_names,
        model_class=gemma3_decoder.Decoder,
        custom_loader=custom_loader,
    )
def _pick_fresh_tflite(out_dir: Path, prefix: str, since: float):
    """Return the newest ``*.tflite`` in *out_dir* written at or after *since*.

    Fresh files whose name starts with *prefix* are preferred. If none match,
    the newest fresh file of any name is returned. Returns None when there is
    no fresh file, so stale outputs from earlier runs or other quantize modes
    are never taken as this run's result.
    """
    # Small slack for filesystems with coarse mtime resolution.
    threshold = since - 2.0
    fresh = [p for p in Path(out_dir).glob("*.tflite") if p.stat().st_mtime >= threshold]
    pool = [p for p in fresh if p.name.startswith(prefix)] or fresh
    if not pool:
        return None
    return max(pool, key=lambda p: p.stat().st_mtime)


def strategy1_litert_native(model_dir: Path, out_dir: Path, model_type: str, quantize: str, prefill: int, kvcache: int):
    """
    Native LiteRT conversion with explicit quantization support.

    For gemma3, the custom TranslateGemma 4B builder is tried first, then
    litert_torch's example builders. gemma/gemma2 use the gemma2 example
    builders. The model is exported with ``converter.convert_to_tflite``.

    Returns the path of the .tflite written by this call. Returns None when
    *quantize* has no converter mapping, the model type is not handled, no
    builder succeeds, or no fresh output file is found. Builder exceptions
    are logged; converter exceptions propagate to the caller.
    """
    import time

    log("Strategy 1: litert-torch native")
    from litert_torch.generative.utilities import converter
    from litert_torch.generative.utilities.export_config import ExportConfig
    from litert_torch.generative.layers import kv_cache
    export_config = ExportConfig()
    export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    export_config.mask_as_input = True
    # Map our quantize flags to converter's QuantizationName values
    QUANT_MAP = {
        "none": "none",
        "dynamic_int8": "dynamic_int8",
        "int8": "weight_only_int8",
        "float16": "fp16",
        "int4": "dynamic_int4_block128",
    }
    quant_for_converter = QUANT_MAP.get(quantize)
    if quant_for_converter is None:
        warn(f"No converter mapping for '{quantize}', falling back to Strategy 2")
        return None
    model = None
    if model_type == "gemma3":
        # Try custom 4B builder first (handles TranslateGemma 4B multimodal weight prefix)
        log("Trying custom build_translategemma_4b ...")
        try:
            model = build_translategemma_4b(str(model_dir))
            if model is not None:
                model.eval()
                log(" build_translategemma_4b success")
        except Exception as e:
            warn(f" build_translategemma_4b failed: {e}")
            model = None
        if model is None:
            mod = importlib.import_module("litert_torch.generative.examples.gemma3.gemma3")
            available = [n for n in dir(mod) if n.startswith("build_model")]
            log(f"Gemma3 builders available: {available}")
            preferred = ["build_model_4b", "build_model_2b", "build_model_1b", "build_model_270m", "build_model"]
            ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
            model = try_builders(mod, ordered, model_dir)
    elif model_type in ("gemma", "gemma2"):
        mod = importlib.import_module("litert_torch.generative.examples.gemma2.gemma2")
        available = [n for n in dir(mod) if n.startswith("build_model")]
        log(f"Gemma2 builders available: {available}")
        preferred = ["build_model_4b", "build_model_2b", "build_model"]
        ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
        model = try_builders(mod, ordered, model_dir)
    else:
        warn(f"Model type '{model_type}' not handled by native strategy")
        return None
    if model is None:
        warn("Strategy 1 did not find a compatible builder")
        return None
    output_prefix = f"translategemma-4b-it-{quantize}"
    # Taken before export so only files written by this call are eligible.
    # The old `*{quantize}*` glob also matched e.g. dynamic_int8 files for
    # int8, and leftovers from earlier runs.
    started = time.time()
    converter.convert_to_tflite(
        model,
        output_path=str(out_dir),
        output_name_prefix=output_prefix,
        prefill_seq_len=prefill,
        kv_cache_max_len=kvcache,
        quantize=quant_for_converter,
        export_config=export_config,
    )
    return _pick_fresh_tflite(out_dir, output_prefix, started)
def strategy2_generic(model_dir: Path, out_dir: Path, prefill: int, quantize: str):
    """
    Generic fallback conversion (logits-only, no KV cache).

    The checkpoint is loaded with transformers' AutoModelForCausalLM. The
    converted wrapper takes only ``input_ids`` and returns only logits.
    Weights are float32, except that ``quantize == "float16"`` loads the
    model in float16, with the embedding and lm_head cast back to float32.
    The output file name records which of the two was used
    (``generic-none`` or ``generic-fp16``). Use
    strategy3_post_tflite_quantize() afterwards for int4/int8 compression.

    Returns the exported .tflite path, or None if the file was not written.
    """
    log("Strategy 2: ai_edge_torch generic (wrapped logits-only)")
    import torch
    try:
        import litert_torch as ai_edge_torch
    except Exception:
        import ai_edge_torch  # deprecated fallback
    from transformers import AutoConfig, AutoModelForCausalLM
    dtype = torch.float16 if quantize == "float16" else torch.float32
    cfg = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=True)
    vocab = getattr(cfg, "vocab_size", None)
    if vocab is None and hasattr(cfg, "text_config"):
        vocab = getattr(cfg.text_config, "vocab_size", None)
    vocab = int(vocab or 262144)
    log(f"Loading HF model on CPU with dtype={dtype} ...")
    base_model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    base_model.eval()
    # TFLite embedding_lookup does not support f16 weights — keep embedding in float32
    if dtype == torch.float16:
        log("Casting embedding and lm_head to float32 for TFLite compatibility")
        if hasattr(base_model, "model") and hasattr(base_model.model, "embed_tokens"):
            base_model.model.embed_tokens = base_model.model.embed_tokens.to(torch.float32)
        if hasattr(base_model, "lm_head"):
            base_model.lm_head = base_model.lm_head.to(torch.float32)
    if hasattr(base_model, "config") and hasattr(base_model.config, "use_cache"):
        base_model.config.use_cache = False

    class LogitsOnlyWrapper(torch.nn.Module):
        """Expose only ``input_ids -> logits`` so the export has one input/output."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            out = self.model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
            return out[0] if isinstance(out, (tuple, list)) else out.logits

    wrapped = LogitsOnlyWrapper(base_model).eval()
    # Trace with at least one token: prefill <= 0 would give an empty input.
    seq_len = max(1, min(prefill, 128))
    sample_ids = torch.randint(0, vocab, (1, seq_len), dtype=torch.int64)
    # int4/int8 quantization is applied post-export via Strategy 3.
    weight_tag = "fp16" if dtype == torch.float16 else "none"
    out_file = out_dir / f"translategemma-4b-it-generic-{weight_tag}.tflite"
    edge_model = ai_edge_torch.convert(wrapped, (sample_ids,))
    edge_model.export(str(out_file))
    return out_file if out_file.exists() else None
def strategy3_post_tflite_quantize(tflite_in: Path, out_dir: Path, quantize: str):
    """
    Post-conversion weight quantization applied directly to a TFLite flatbuffer
    using ai_edge_quantizer (bundled with litert_torch).
    Supported modes and their recipes (per get_supported_layer_schemes()):
      int4         -> INT4 DYNAMIC_RANGE BLOCKWISE_128 (~2 GB for 4B model)
      int8         -> INT8 WEIGHT_ONLY CHANNELWISE (~4 GB)
      dynamic_int8 -> INT8 DYNAMIC_RANGE CHANNELWISE (~4 GB)
      float16      -> FP16 WEIGHT_ONLY FLOAT_CAST (~8 GB)

    Writes ``out_dir/translategemma-4b-it-<quantize>.tflite`` and returns its
    path. Returns None, after a warning, for modes without a recipe.

    The result is first written to a ``.partial`` sibling and then renamed
    into place. A failure mid-write therefore never leaves a truncated
    ``.tflite`` that a later ``*.tflite`` glob could pick up. The input is
    read fully before writing, so an input that shares the output's name is
    replaced only after quantization succeeds.
    """
    log(f"Strategy 3: post-TFLite quantization ({quantize}) on {tflite_in.name}")
    from litert_torch.generative.quantize import quant_attrs as qa, quant_recipe, quant_recipe_utils
    from litert_torch.quantize import translate_recipe
    if quantize == "int4":
        layer = quant_recipe_utils.create_layer_quant_dynamic(qa.Dtype.INT4, qa.Granularity.BLOCKWISE_128)
    elif quantize == "int8":
        layer = quant_recipe_utils.create_layer_quant_weight_only(qa.Dtype.INT8, qa.Granularity.CHANNELWISE)
    elif quantize == "dynamic_int8":
        layer = quant_recipe_utils.create_layer_quant_dynamic(qa.Dtype.INT8, qa.Granularity.CHANNELWISE)
    elif quantize == "float16":
        layer = quant_recipe_utils.create_layer_quant_fp16()
    else:
        warn(f"Strategy 3: no post-TFLite recipe for '{quantize}', skipping")
        return None
    gen_recipe = quant_recipe.GenerativeQuantRecipe(default=layer)
    ai_recipe = translate_recipe.translate_to_ai_edge_recipe(gen_recipe)
    model_bytes = tflite_in.read_bytes()
    log(f"  Input: {len(model_bytes) / 1024**3:.2f} GB — quantizing ...")
    quantized_bytes = translate_recipe.quantize_model(model_bytes, ai_recipe)
    out_file = out_dir / f"translategemma-4b-it-{quantize}.tflite"
    tmp_file = out_file.with_name(out_file.name + ".partial")
    try:
        tmp_file.write_bytes(quantized_bytes)
        tmp_file.replace(out_file)
    finally:
        # Only present if the write or rename failed; never leave it behind.
        if tmp_file.exists():
            tmp_file.unlink()
    log(f"  Output: {len(quantized_bytes) / 1024**3:.2f} GB -> {out_file}")
    return out_file
def ensure_tokenizer_model(model_dir: Path):
    """Return ``model_dir/tokenizer.model``, creating it from tokenizer.json if needed.

    An existing tokenizer.model is returned as-is. Otherwise the litert_torch
    ``tokenizer_to_sentencepiece`` tool is run in a subprocess. Its stdout is
    echoed, and its stderr is echoed only on failure. Calls die() if neither
    tokenizer file exists, if the tool fails, or if it produces no file.
    """
    sp_model = model_dir / "tokenizer.model"
    if sp_model.exists():
        return sp_model

    hf_tokenizer = model_dir / "tokenizer.json"
    if not hf_tokenizer.exists():
        die(f"Missing tokenizer files: neither {sp_model} nor {hf_tokenizer} exists")

    log("Converting tokenizer.json -> tokenizer.model")
    tool_argv = [sys.executable, "-m", "litert_torch.generative.tools.tokenizer_to_sentencepiece"]
    tool_argv += [f"--checkpoint={model_dir}", f"--output_path={sp_model}"]
    proc = subprocess.run(tool_argv, text=True, capture_output=True)

    if proc.stdout:
        print(proc.stdout, flush=True)
    failed = proc.returncode != 0
    if failed and proc.stderr:
        print(proc.stderr, file=sys.stderr, flush=True)
    if failed:
        die("Tokenizer conversion failed")
    if not sp_model.exists():
        die("Tokenizer conversion reported success but tokenizer.model is missing")
    return sp_model
def bundle_task(tflite_file: Path, tokenizer_model: Path, task_file: Path):
    """Package a TFLite model and a tokenizer model into a MediaPipe .task file.

    Only BundleConfig fields accepted by the installed mediapipe version are
    passed; this is checked with inspect.signature. The start token is
    ``<bos>``, and the stop tokens are ``<eos>`` and ``<end_of_turn>``.
    Prompt prefix and suffix are left empty, so the caller must supply the
    prompt template. The parent directory of *task_file* is created if
    missing.
    """
    log(f"Bundling .task -> {task_file}")
    from mediapipe.tasks.python.genai import bundler
    task_file.parent.mkdir(parents=True, exist_ok=True)
    params = inspect.signature(bundler.BundleConfig).parameters
    log(f"BundleConfig params: {list(params.keys())}")
    kwargs = {
        "tflite_model": str(tflite_file),
        "tokenizer_model": str(tokenizer_model),
        "output_filename": str(task_file),
    }
    if "start_token" in params:
        kwargs["start_token"] = "<bos>"
    elif "start_tokens" in params:
        kwargs["start_tokens"] = ["<bos>"]
    if "stop_tokens" in params:
        # Gemma chat-format models end a turn with <end_of_turn>. Without it,
        # decoding would only stop at <eos> and could run past the reply.
        # NOTE(review): assumes the tokenizer defines <end_of_turn> — confirm.
        kwargs["stop_tokens"] = ["<eos>", "<end_of_turn>"]
    if "prompt_prefix" in params:
        kwargs["prompt_prefix"] = ""
    if "prompt_suffix" in params:
        kwargs["prompt_suffix"] = ""
    # Drop anything this mediapipe version's BundleConfig does not accept.
    kwargs = {k: v for k, v in kwargs.items() if k in params}
    cfg = bundler.BundleConfig(**kwargs)
    bundler.create_bundle(cfg)
def _apply_post_quantize(tflite_file: Path, tflite_dir: Path, q: str, skip_msg: str) -> Path:
    """Run Strategy 3 on *tflite_file*; return the quantized path, or the input on skip/failure."""
    try:
        quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
    except Exception as e:
        # Best-effort: a failed post-quantization still bundles the input model.
        warn(f"Strategy 3 failed: {e}")
        traceback.print_exc()
        return tflite_file
    if quantized:
        log(f"Strategy 3 success: {quantized}")
        return quantized
    warn(skip_msg)
    return tflite_file


def main():
    """CLI entry point: download/convert TranslateGemma, then bundle a MediaPipe .task.

    Normal mode tries Strategy 1 (native litert-torch with quantization).
    Strategy 2 (generic logits-only export) is the fallback, followed by
    Strategy 3 (post-TFLite quantization) when requested. ``--bundle-only``
    skips conversion, applies Strategy 3 to ``--existing-tflite`` unless
    ``--quantize none``, and bundles the result. Exits via die() when no
    model was produced or the .task file was not created.
    """
    ap = argparse.ArgumentParser(description="TranslateGemma -> Android .task converter")
    ap.add_argument("--model-id", default="google/translategemma-4b-it")
    ap.add_argument("--model-dir", default="./translategemma-4b-it")
    ap.add_argument("--tflite-dir", default="./tflite_output")
    ap.add_argument("--output-dir", default="./output")
    ap.add_argument("--task-file", default="./output/translategemma-4b-it-android.task")
    ap.add_argument("--quantize", default="dynamic_int8", help=f"One of: {sorted(SUPPORTED_QUANT)}")
    ap.add_argument("--prefill-seq-len", "--prefill", dest="prefill_seq_len", type=int, default=1024)
    ap.add_argument("--kv-cache-max-len", "--kvcache", dest="kv_cache_max_len", type=int, default=1024)
    ap.add_argument("--skip-strategy1", action="store_true")
    ap.add_argument("--bundle-only", action="store_true", help="Skip conversion; only bundle existing TFLite")
    ap.add_argument("--existing-tflite", default="", help="Path to an existing .tflite to bundle")
    ap.add_argument("--allow-no-token", action="store_true", help="Allow model download from public repo without HF_TOKEN")
    args = ap.parse_args()

    q = normalize_quantize(args.quantize)
    hf_token = os.environ.get("HF_TOKEN", "").strip()
    if not hf_token and not args.allow_no_token:
        warn("HF_TOKEN is not set. If model is public, you can pass --allow-no-token.")

    model_dir = Path(args.model_dir)
    tflite_dir = Path(args.tflite_dir)
    output_dir = Path(args.output_dir)
    task_file = Path(args.task_file)
    tflite_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    tflite_file = None
    if args.bundle_only:
        if not args.existing_tflite:
            die("--bundle-only requires --existing-tflite /path/to/model.tflite")
        tflite_file = Path(args.existing_tflite)
        if not tflite_file.exists():
            die(f"Existing tflite not found: {tflite_file}")
        log(f"Bundle-only mode using: {tflite_file}")
        # Apply post-TFLite quantization if requested
        if q != "none":
            tflite_file = _apply_post_quantize(
                tflite_file, tflite_dir, q, "Strategy 3 skipped; bundling input as-is"
            )
    else:
        ensure_model_downloaded(args.model_id, model_dir, hf_token)
        arch = inspect_arch(model_dir)
        model_type = arch["model_type"]
        strategy1_succeeded = False
        if not args.skip_strategy1:
            try:
                tflite_file = strategy1_litert_native(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    model_type=model_type,
                    quantize=q,
                    prefill=args.prefill_seq_len,
                    kvcache=args.kv_cache_max_len,
                )
                if tflite_file:
                    log(f"Strategy 1 success: {tflite_file}")
                    strategy1_succeeded = True
            except Exception as e:
                warn(f"Strategy 1 failed: {e}")
                traceback.print_exc()
        if not tflite_file:
            try:
                tflite_file = strategy2_generic(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    prefill=args.prefill_seq_len,
                    quantize=q,
                )
                if tflite_file:
                    log(f"Strategy 2 success: {tflite_file}")
                    warn("Generic TFLite may not have MediaPipe LLM prefill/decode signatures.")
            except Exception as e:
                warn(f"Strategy 2 failed: {e}")
                traceback.print_exc()
        # Strategy 3 only after Strategy 2: Strategy 1 already applies the
        # requested quantization natively via the converter.
        if tflite_file and not strategy1_succeeded and q != "none":
            tflite_file = _apply_post_quantize(
                tflite_file, tflite_dir, q, "Strategy 3 skipped; bundling unquantized model"
            )

    if not tflite_file:
        die("All conversion strategies failed")
    tokenizer_model = ensure_tokenizer_model(model_dir)
    bundle_task(tflite_file, tokenizer_model, task_file)
    if not task_file.exists():
        die(f"Bundler returned but {task_file} was not created")
    log(f"DONE: {task_file}")
    log(f"Size: {task_file.stat().st_size / (1024 * 1024):.2f} MB")


if __name__ == "__main__":
    main()