| |
| import argparse |
| import importlib |
| import inspect |
| import json |
| import os |
| import subprocess |
| import sys |
| import traceback |
| from pathlib import Path |
|
|
# Process-wide environment defaults (set before any heavy imports happen):
# - TRANSFORMERS_NO_TORCHVISION: presumably keeps `transformers` from importing
#   torchvision (not needed for text-only conversion) — TODO confirm flag name
#   against the installed transformers version.
# - PYTHONUNBUFFERED: stream logs promptly when run under a pipe/CI.
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ.setdefault("PYTHONUNBUFFERED", "1")
|
|
|
|
def log(msg):
    """Emit an informational message prefixed with '[+]', flushed immediately."""
    print(f"[+] {msg}", flush=True)
def warn(msg):
    """Emit a warning message prefixed with '[!]', flushed immediately."""
    print(f"[!] {msg}", flush=True)
def die(msg):
    """Print an '[x]'-prefixed error to stderr and exit the process with code 1."""
    print(f"[x] {msg}", file=sys.stderr, flush=True)
    raise SystemExit(1)
|
|
|
|
# Canonical quantization-mode names accepted throughout the pipeline.
# CLI aliases (fp16, q8, i4, ...) are mapped onto these by normalize_quantize().
SUPPORTED_QUANT = {
    "none",
    "dynamic_int8",
    "float16",
    "int8",
    "int4",
}
|
|
|
|
def normalize_quantize(q: str) -> str:
    """Canonicalize a user-supplied --quantize value.

    Empty/None input defaults to "dynamic_int8". Common shorthand aliases
    are mapped to canonical names; anything not in SUPPORTED_QUANT aborts
    the process via die().
    """
    alias_table = {
        "fp32": "none",
        "no": "none",
        "off": "none",
        "fp16": "float16",
        "f16": "float16",
        "i8": "int8",
        "q8": "int8",
        "i4": "int4",
        "q4": "int4",
    }
    normalized = (q or "dynamic_int8").strip().lower()
    normalized = alias_table.get(normalized, normalized)
    if normalized not in SUPPORTED_QUANT:
        die(f"Unsupported --quantize '{normalized}'. Supported: {sorted(SUPPORTED_QUANT)}")
    return normalized
|
|
|
|
def load_config(model_dir: Path):
    """Read HF-style config.json from *model_dir*.

    Returns a (full_config, text_config) tuple; text_config is the
    "text_config" sub-dict for multimodal configs, or {} when absent.
    Aborts via die() if config.json does not exist.
    """
    config_file = model_dir / "config.json"
    if not config_file.exists():
        die(f"Missing config: {config_file}")
    parsed = json.loads(config_file.read_text())
    return parsed, parsed.get("text_config", {})
|
|
|
|
def inspect_arch(model_dir: Path):
    """Summarize and log the model architecture read from config.json.

    Returns a dict with model_type, the first listed architecture, and
    vocab_size (falling back to text_config's value, then 262144).
    """
    cfg, text_cfg = load_config(model_dir)
    architectures = cfg.get("architectures") or ["unknown"]
    summary = {
        "model_type": cfg.get("model_type", "unknown"),
        "architecture": architectures[0],
        "vocab_size": cfg.get("vocab_size", text_cfg.get("vocab_size", 262144)),
    }
    log(f"ARCH: {summary}")
    return summary
|
|
|
|
def ensure_model_downloaded(model_id: str, model_dir: Path, hf_token: str):
    """Fetch *model_id* into *model_dir* via huggingface_hub unless a
    config.json is already present there.

    An empty hf_token is passed as None (anonymous/public access).
    Aborts via die() if the download raises.
    """
    if (model_dir / "config.json").exists():
        log(f"Using existing model dir: {model_dir}")
        return

    log(f"Downloading {model_id} -> {model_dir}")
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=model_id,
            local_dir=str(model_dir),
            token=hf_token or None,
            local_dir_use_symlinks=False,
        )
    except Exception as exc:
        die(f"Model download failed: {exc}")
|
|
|
|
def try_builders(mod, builder_names, model_dir: Path):
    """Attempt each named builder function on *mod* until one produces a model.

    Each builder is called with str(model_dir). Tuple/list results are
    unwrapped to their first element. The first result exposing .eval()
    is switched to eval mode and returned; any exception or unusable
    return just logs a warning and moves to the next candidate.
    Returns None when every builder fails.
    """
    for name in builder_names:
        builder = getattr(mod, name, None)
        if builder is None:
            continue
        log(f"Trying {mod.__name__}.{name} ...")
        try:
            candidate = builder(str(model_dir))
            if candidate is None:
                warn(" returned None")
                continue
            if isinstance(candidate, (tuple, list)) and len(candidate) > 0:
                candidate = candidate[0]
            if not hasattr(candidate, "eval"):
                warn(f" unsupported return type: {type(candidate)}")
                continue
            candidate.eval()
            log(f" success with {name}")
            return candidate
        except Exception as exc:
            warn(f" failed: {exc}")
    return None
|
|
|
|
def build_translategemma_4b(checkpoint_path: str):
    """
    Custom builder for TranslateGemma 4B IT (Gemma3 multimodal decoder).
    Strips 'language_model.' prefix from safetensors keys so the standard
    TENSOR_NAMES_SEP_QKV mapping works.

    Architecture (from config.json / verified weight shapes):
    34 layers, embedding_dim=2560, 8 heads, head_dim=256, 4 KV heads,
    intermediate=10240, sliding_window=1024, global every 6th layer.

    Returns whatever model_builder.build_decoder_only_model returns
    (a LiteRT decoder-only model instance).
    """
    import safetensors.torch as st_lib
    import json as json_lib
    from litert_torch.generative.utilities import model_builder, loader as loading_utils
    from litert_torch.generative.layers import kv_cache as kv_utils  # NOTE(review): unused here — candidate for removal
    from litert_torch.generative.examples.gemma3 import decoder as gemma3_decoder
    import litert_torch.generative.layers.model_config as cfg_mod

    # Zero-centered RMSNorm — presumably matches the Gemma convention of
    # storing norm weights as (w - 1); TODO confirm against checkpoint.
    norm = cfg_mod.NormalizationConfig(
        type=cfg_mod.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
    )
    # Gated FFN (GeGLU with tanh-approx GELU), pre- and post-FF norms applied.
    ff = cfg_mod.FeedForwardConfig(
        type=cfg_mod.FeedForwardType.GATED,
        activation=cfg_mod.ActivationConfig(cfg_mod.ActivationType.GELU_TANH),
        intermediate_size=10240,
        pre_ff_norm_config=norm, post_ff_norm_config=norm,
    )

    def blk(idx):
        # Per-layer attention config: every 6th layer (1-indexed: idx+1 % 6 == 0)
        # uses GLOBAL attention with RoPE base 1,000,000; the rest use
        # LOCAL_SLIDING attention (window 1024) with RoPE base 10,000.
        attn = cfg_mod.AttentionConfig(
            num_heads=8, head_dim=256, num_query_groups=4,
            rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
            rotary_percentage=1.0, qkv_transpose_before_split=True,
            qkv_fused_interleaved=False,
            query_norm_config=norm, key_norm_config=norm, logit_softcap=None,
            sliding_window_size=1024,
            attn_type=cfg_mod.AttentionType.GLOBAL if (idx + 1) % 6 == 0
            else cfg_mod.AttentionType.LOCAL_SLIDING,
        )
        return cfg_mod.TransformerBlockConfig(
            attn_config=attn, ff_config=ff,
            pre_attention_norm_config=norm, post_attention_norm_config=norm,
        )

    # NOTE(review): vocab_size=262208 here vs 262144 used as the fallback
    # elsewhere in this file — presumably the padded embedding row count from
    # the actual weights; confirm against embed_tokens shape.
    # embedding_scale = sqrt(embedding_dim), Gemma-style input scaling.
    model_config = cfg_mod.ModelConfig(
        vocab_size=262208, num_layers=34, max_seq_len=8192,
        embedding_dim=2560, embedding_scale=2560 ** 0.5,
        block_configs=[blk(i) for i in range(34)],
        final_norm_config=norm, lm_head_use_bias=False, final_logit_softcap=None,
    )

    # Tensor-name templates ({} is the layer index) after the
    # 'language_model.' prefix has been stripped by custom_loader below.
    # lm_head=None: presumably tied to the embedding matrix — confirm.
    tensor_names = loading_utils.ModelLoader.TensorNames(
        ff_up_proj="model.layers.{}.mlp.up_proj",
        ff_down_proj="model.layers.{}.mlp.down_proj",
        ff_gate_proj="model.layers.{}.mlp.gate_proj",
        attn_query_proj="model.layers.{}.self_attn.q_proj",
        attn_key_proj="model.layers.{}.self_attn.k_proj",
        attn_value_proj="model.layers.{}.self_attn.v_proj",
        attn_output_proj="model.layers.{}.self_attn.o_proj",
        attn_query_norm="model.layers.{}.self_attn.q_norm",
        attn_key_norm="model.layers.{}.self_attn.k_norm",
        pre_attn_norm="model.layers.{}.input_layernorm",
        post_attn_norm="model.layers.{}.post_attention_layernorm",
        pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
        post_ff_norm="model.layers.{}.post_feedforward_layernorm",
        embedding="model.embed_tokens",
        final_norm="model.norm",
        lm_head=None,
    )

    def custom_loader(path: str):
        # Load every shard listed in the safetensors index and merge into one
        # dict, dropping the 'language_model.' prefix so keys match
        # tensor_names above. Vision-tower keys (if any) keep their names and
        # are simply ignored by the loader mapping.
        idx = json_lib.loads((Path(path) / "model.safetensors.index.json").read_text())
        out = {}
        for fname in set(idx["weight_map"].values()):
            for k, v in st_lib.load_file(str(Path(path) / fname)).items():
                out[k[len("language_model."):] if k.startswith("language_model.") else k] = v
        return out

    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=model_config,
        tensor_names=tensor_names,
        model_class=gemma3_decoder.Decoder,
        custom_loader=custom_loader,
    )
|
|
|
|
def strategy1_litert_native(model_dir: Path, out_dir: Path, model_type: str, quantize: str, prefill: int, kvcache: int):
    """
    Native LiteRT conversion with explicit quantization support.

    Tries the custom TranslateGemma builder first (gemma3 only), then any
    build_model* functions discovered in the litert_torch example modules.
    Converts the resulting model to TFLite with prefill/decode signatures.

    Returns the produced .tflite Path, or None when no builder/mapping
    applies (caller falls through to Strategy 2).
    """
    log("Strategy 1: litert-torch native")

    from litert_torch.generative.utilities import converter
    from litert_torch.generative.utilities.export_config import ExportConfig
    from litert_torch.generative.layers import kv_cache

    # NOTE(review): transposed KV layout + mask-as-input are presumably what
    # the MediaPipe LLM runtime expects — confirm against runtime docs.
    export_config = ExportConfig()
    export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    export_config.mask_as_input = True

    # Map this script's canonical quantize names to the converter's
    # quantize-string vocabulary.
    QUANT_MAP = {
        "none": "none",
        "dynamic_int8": "dynamic_int8",
        "int8": "weight_only_int8",
        "float16": "fp16",
        "int4": "dynamic_int4_block128",
    }
    quant_for_converter = QUANT_MAP.get(quantize)
    if quant_for_converter is None:
        # Defensive: unreachable while QUANT_MAP covers all of SUPPORTED_QUANT.
        warn(f"No converter mapping for '{quantize}', falling back to Strategy 2")
        return None

    model = None

    if model_type == "gemma3":
        # Preferred path: hand-tuned TranslateGemma 4B config.
        log("Trying custom build_translategemma_4b ...")
        try:
            model = build_translategemma_4b(str(model_dir))
            if model is not None:
                model.eval()
                log(" build_translategemma_4b success")
        except Exception as e:
            warn(f" build_translategemma_4b failed: {e}")
            model = None

        # Fallback: probe the stock gemma3 example builders, largest first.
        if model is None:
            mod = importlib.import_module("litert_torch.generative.examples.gemma3.gemma3")
            available = [n for n in dir(mod) if n.startswith("build_model")]
            log(f"Gemma3 builders available: {available}")
            preferred = ["build_model_4b", "build_model_2b", "build_model_1b", "build_model_270m", "build_model"]
            ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
            model = try_builders(mod, ordered, model_dir)

    elif model_type in ("gemma", "gemma2"):
        mod = importlib.import_module("litert_torch.generative.examples.gemma2.gemma2")
        available = [n for n in dir(mod) if n.startswith("build_model")]
        log(f"Gemma2 builders available: {available}")
        preferred = ["build_model_4b", "build_model_2b", "build_model"]
        ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
        model = try_builders(mod, ordered, model_dir)

    else:
        warn(f"Model type '{model_type}' not handled by native strategy")
        return None

    if model is None:
        warn("Strategy 1 did not find a compatible builder")
        return None

    # Converter writes files into out_dir; exact filenames depend on the
    # converter version, hence the glob below rather than a fixed name.
    converter.convert_to_tflite(
        model,
        output_path=str(out_dir),
        output_name_prefix=f"translategemma-4b-it-{quantize}",
        prefill_seq_len=prefill,
        kv_cache_max_len=kvcache,
        quantize=quant_for_converter,
        export_config=export_config,
    )

    # Prefer an output whose name mentions the quantize mode; otherwise take
    # any .tflite the converter produced.
    produced = sorted(out_dir.glob(f"*{quantize}*.tflite")) or sorted(out_dir.glob("*.tflite"))
    return produced[0] if produced else None
|
|
|
|
def strategy2_generic(model_dir: Path, out_dir: Path, prefill: int, quantize: str):
    """
    Generic fallback conversion (logits-only). Always exports float32; use
    strategy3_post_tflite_quantize() afterwards for real int4/int8 compression.

    Loads the HF model with AutoModelForCausalLM, wraps it so forward()
    returns only logits (no KV cache), and converts with ai_edge_torch.
    Returns the written .tflite Path, or None if the export produced nothing.
    """
    log("Strategy 2: ai_edge_torch generic (wrapped logits-only)")

    import torch
    # litert_torch is the renamed successor of ai_edge_torch; accept either.
    try:
        import litert_torch as ai_edge_torch
    except Exception:
        import ai_edge_torch

    from transformers import AutoConfig, AutoModelForCausalLM

    # float16 is the only mode that changes load dtype here; everything else
    # loads float32 (quantization happens later in Strategy 3).
    dtype = torch.float32
    if quantize == "float16":
        dtype = torch.float16

    # Resolve vocab size: top-level config, then text_config, then a
    # hard-coded Gemma-family default.
    cfg = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=True)
    vocab = getattr(cfg, "vocab_size", None)
    if vocab is None and hasattr(cfg, "text_config"):
        vocab = getattr(cfg.text_config, "vocab_size", None)
    vocab = int(vocab or 262144)

    log(f"Loading HF model on CPU with dtype={dtype} ...")
    base_model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    base_model.eval()

    # NOTE(review): fp16 embedding/lm_head reportedly trip up the TFLite
    # converter, hence the fp32 cast — confirm this is still needed with the
    # installed converter version.
    if dtype == torch.float16:
        log("Casting embedding and lm_head to float32 for TFLite compatibility")
        if hasattr(base_model, "model") and hasattr(base_model.model, "embed_tokens"):
            base_model.model.embed_tokens = base_model.model.embed_tokens.to(torch.float32)
        if hasattr(base_model, "lm_head"):
            base_model.lm_head = base_model.lm_head.to(torch.float32)

    # Disable KV caching so tracing sees a single pure tensor->tensor graph.
    if hasattr(base_model, "config") and hasattr(base_model.config, "use_cache"):
        base_model.config.use_cache = False

    class LogitsOnlyWrapper(torch.nn.Module):
        # Thin adapter: forward(input_ids) -> logits tensor only, so the
        # exported graph has a single tensor output.
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            out = self.model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
            logits = out[0] if isinstance(out, (tuple, list)) else out.logits
            return logits

    wrapped = LogitsOnlyWrapper(base_model).eval()
    # Trace with a sample capped at 128 tokens; presumably to bound
    # export-time memory — the traced shape becomes the fixed input shape.
    sample_ids = torch.randint(0, vocab, (1, min(prefill, 128)), dtype=torch.int64)

    # Name always says "none": this path exports unquantized float weights
    # regardless of the requested mode (see docstring).
    out_file = out_dir / f"translategemma-4b-it-generic-none.tflite"
    edge_model = ai_edge_torch.convert(wrapped, (sample_ids,))
    edge_model.export(str(out_file))

    return out_file if out_file.exists() else None
|
|
|
|
def strategy3_post_tflite_quantize(tflite_in: Path, out_dir: Path, quantize: str):
    """
    Post-conversion weight quantization applied directly to a TFLite flatbuffer
    using ai_edge_quantizer (bundled with litert_torch).

    Recipe per mode (see get_supported_layer_schemes()):
      int4         -> INT4 DYNAMIC_RANGE BLOCKWISE_128 (~2 GB for 4B model)
      int8         -> INT8 WEIGHT_ONLY CHANNELWISE     (~4 GB)
      dynamic_int8 -> INT8 DYNAMIC_RANGE CHANNELWISE   (~4 GB)
      float16      -> FP16 WEIGHT_ONLY FLOAT_CAST      (~8 GB)

    Returns the quantized .tflite Path, or None for unsupported modes.
    """
    log(f"Strategy 3: post-TFLite quantization ({quantize}) on {tflite_in.name}")

    from litert_torch.generative.quantize import quant_attrs as qa, quant_recipe, quant_recipe_utils
    from litert_torch.quantize import translate_recipe

    # Dispatch table from mode name to a zero-arg recipe factory; lambdas
    # defer construction until after the unsupported-mode check.
    recipe_factories = {
        "int4": lambda: quant_recipe_utils.create_layer_quant_dynamic(
            qa.Dtype.INT4, qa.Granularity.BLOCKWISE_128),
        "int8": lambda: quant_recipe_utils.create_layer_quant_weight_only(
            qa.Dtype.INT8, qa.Granularity.CHANNELWISE),
        "dynamic_int8": lambda: quant_recipe_utils.create_layer_quant_dynamic(
            qa.Dtype.INT8, qa.Granularity.CHANNELWISE),
        "float16": lambda: quant_recipe_utils.create_layer_quant_fp16(),
    }
    factory = recipe_factories.get(quantize)
    if factory is None:
        warn(f"Strategy 3: no post-TFLite recipe for '{quantize}', skipping")
        return None

    layer_recipe = factory()
    gen_recipe = quant_recipe.GenerativeQuantRecipe(default=layer_recipe)
    ai_recipe = translate_recipe.translate_to_ai_edge_recipe(gen_recipe)

    raw_bytes = tflite_in.read_bytes()
    log(f" Input: {len(raw_bytes) / 1024**3:.2f} GB — quantizing ...")
    quantized = translate_recipe.quantize_model(raw_bytes, ai_recipe)

    target = out_dir / f"translategemma-4b-it-{quantize}.tflite"
    target.write_bytes(quantized)
    log(f" Output: {len(quantized) / 1024**3:.2f} GB -> {target}")
    return target
|
|
|
|
def ensure_tokenizer_model(model_dir: Path):
    """Return the sentencepiece tokenizer.model path, creating it if needed.

    If tokenizer.model is absent but tokenizer.json exists, convert it via
    the litert_torch tokenizer_to_sentencepiece tool in a subprocess.
    Aborts via die() when neither file exists, when conversion fails, or
    when the tool exits 0 without producing the file.
    """
    tok_model = model_dir / "tokenizer.model"
    if tok_model.exists():
        return tok_model

    if not (model_dir / "tokenizer.json").exists():
        tok_json = model_dir / "tokenizer.json"
        die(f"Missing tokenizer files: neither {tok_model} nor {tok_json} exists")

    log("Converting tokenizer.json -> tokenizer.model")
    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "litert_torch.generative.tools.tokenizer_to_sentencepiece",
            f"--checkpoint={model_dir}",
            f"--output_path={tok_model}",
        ],
        text=True,
        capture_output=True,
    )
    # Mirror the tool's stdout; stderr is only surfaced on failure.
    if result.stdout:
        print(result.stdout, flush=True)
    if result.returncode != 0:
        if result.stderr:
            print(result.stderr, file=sys.stderr, flush=True)
        die("Tokenizer conversion failed")

    if not tok_model.exists():
        die("Tokenizer conversion reported success but tokenizer.model is missing")

    return tok_model
|
|
|
|
def bundle_task(tflite_file: Path, tokenizer_model: Path, task_file: Path):
    """Package a TFLite model and tokenizer into a MediaPipe .task bundle.

    Introspects bundler.BundleConfig's signature so the same code works
    across mediapipe versions whose config fields differ (start_token vs
    start_tokens; optional prompt_prefix/prompt_suffix); unknown keys are
    filtered out before constructing the config.
    """
    log(f"Bundling .task -> {task_file}")
    from mediapipe.tasks.python.genai import bundler

    task_file.parent.mkdir(parents=True, exist_ok=True)

    params = inspect.signature(bundler.BundleConfig).parameters
    log(f"BundleConfig params: {list(params.keys())}")

    config_kwargs = {
        "tflite_model": str(tflite_file),
        "tokenizer_model": str(tokenizer_model),
        "output_filename": str(task_file),
    }

    if "start_token" in params:
        config_kwargs["start_token"] = "<bos>"
    elif "start_tokens" in params:
        config_kwargs["start_tokens"] = ["<bos>"]

    if "stop_tokens" in params:
        config_kwargs["stop_tokens"] = ["<eos>"]

    if "prompt_prefix" in params:
        config_kwargs["prompt_prefix"] = ""
    if "prompt_suffix" in params:
        config_kwargs["prompt_suffix"] = ""

    accepted = {key: val for key, val in config_kwargs.items() if key in params}
    bundler.create_bundle(bundler.BundleConfig(**accepted))
|
|
|
|
def main():
    """CLI entry point.

    Pipeline: download model -> Strategy 1 (native LiteRT conversion with
    built-in quantization) -> Strategy 2 (generic float32 export) ->
    Strategy 3 (post-hoc TFLite quantization) -> bundle tokenizer + model
    into a MediaPipe .task file. --bundle-only skips conversion and
    (optionally) quantizes + bundles an existing .tflite.
    """
    ap = argparse.ArgumentParser(description="TranslateGemma -> Android .task converter")

    ap.add_argument("--model-id", default="google/translategemma-4b-it")
    ap.add_argument("--model-dir", default="./translategemma-4b-it")
    ap.add_argument("--tflite-dir", default="./tflite_output")
    ap.add_argument("--output-dir", default="./output")
    ap.add_argument("--task-file", default="./output/translategemma-4b-it-android.task")

    ap.add_argument("--quantize", default="dynamic_int8", help=f"One of: {sorted(SUPPORTED_QUANT)}")

    # Both long and short spellings map onto one dest each.
    ap.add_argument("--prefill-seq-len", "--prefill", dest="prefill_seq_len", type=int, default=1024)
    ap.add_argument("--kv-cache-max-len", "--kvcache", dest="kv_cache_max_len", type=int, default=1024)

    ap.add_argument("--skip-strategy1", action="store_true")
    ap.add_argument("--bundle-only", action="store_true", help="Skip conversion; only bundle existing TFLite")
    ap.add_argument("--existing-tflite", default="", help="Path to an existing .tflite to bundle")
    ap.add_argument("--allow-no-token", action="store_true", help="Allow model download from public repo without HF_TOKEN")

    args = ap.parse_args()
    q = normalize_quantize(args.quantize)

    # Missing HF_TOKEN is only a warning — public repos download anonymously.
    hf_token = os.environ.get("HF_TOKEN", "").strip()
    if not hf_token and not args.allow_no_token:
        warn("HF_TOKEN is not set. If model is public, you can pass --allow-no-token.")

    model_dir = Path(args.model_dir)
    tflite_dir = Path(args.tflite_dir)
    output_dir = Path(args.output_dir)
    task_file = Path(args.task_file)

    tflite_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    tflite_file = None

    if args.bundle_only:
        if not args.existing_tflite:
            die("--bundle-only requires --existing-tflite /path/to/model.tflite")
        tflite_file = Path(args.existing_tflite)
        if not tflite_file.exists():
            die(f"Existing tflite not found: {tflite_file}")
        log(f"Bundle-only mode using: {tflite_file}")

        # Even in bundle-only mode, apply post-hoc quantization when requested;
        # failure is non-fatal — the input is bundled unmodified.
        if q != "none":
            try:
                quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
                if quantized:
                    tflite_file = quantized
                    log(f"Strategy 3 success: {tflite_file}")
                else:
                    warn("Strategy 3 skipped; bundling input as-is")
            except Exception as e:
                warn(f"Strategy 3 failed: {e}")
                traceback.print_exc()
    else:
        ensure_model_downloaded(args.model_id, model_dir, hf_token if hf_token else "")
        arch = inspect_arch(model_dir)
        model_type = arch["model_type"]

        # Strategy 1: native converter (quantizes during conversion).
        strategy1_succeeded = False
        if not args.skip_strategy1:
            try:
                tflite_file = strategy1_litert_native(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    model_type=model_type,
                    quantize=q,
                    prefill=args.prefill_seq_len,
                    kvcache=args.kv_cache_max_len,
                )
                if tflite_file:
                    log(f"Strategy 1 success: {tflite_file}")
                    strategy1_succeeded = True
            except Exception as e:
                warn(f"Strategy 1 failed: {e}")
                traceback.print_exc()

        # Strategy 2: generic float32 export, only if Strategy 1 yielded nothing.
        if not tflite_file:
            try:
                tflite_file = strategy2_generic(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    prefill=args.prefill_seq_len,
                    quantize=q,
                )
                if tflite_file:
                    log(f"Strategy 2 success: {tflite_file}")
                    warn("Generic TFLite may not have MediaPipe LLM prefill/decode signatures.")
            except Exception as e:
                warn(f"Strategy 2 failed: {e}")
                traceback.print_exc()

        # Strategy 3: post-hoc quantization — only needed when Strategy 1 did
        # NOT run (it quantizes natively) and a quantized output was requested.
        if tflite_file and not strategy1_succeeded and q != "none":
            try:
                quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
                if quantized:
                    tflite_file = quantized
                    log(f"Strategy 3 success: {tflite_file}")
                else:
                    warn("Strategy 3 skipped; bundling unquantized model")
            except Exception as e:
                warn(f"Strategy 3 failed: {e}")
                traceback.print_exc()

    if not tflite_file:
        die("All conversion strategies failed")

    tokenizer_model = ensure_tokenizer_model(model_dir)
    bundle_task(tflite_file, tokenizer_model, task_file)

    log(f"DONE: {task_file}")
    if task_file.exists():
        log(f"Size: {task_file.stat().st_size / (1024 * 1024):.2f} MB")
|
|
|
|
# Script entry point guard.
if __name__ == "__main__":
    main()