#!/usr/bin/env python3
"""Convert TranslateGemma 4B IT into an Android MediaPipe LLM `.task` bundle.

Pipeline: download the HF checkpoint, convert to TFLite via up to three
fallback strategies (native litert-torch converter, generic logits-only
export, post-export quantization), then bundle the TFLite model with a
SentencePiece tokenizer into a MediaPipe `.task` file.
"""
import argparse
import importlib
import inspect
import json
import os
import subprocess
import sys
import traceback
from pathlib import Path

# Keep transformers from importing torchvision, and force unbuffered output
# so logs stream in real time (useful in CI / long conversions).
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ.setdefault("PYTHONUNBUFFERED", "1")


def log(msg):
    """Print an info line, flushed immediately."""
    print(f"[+] {msg}", flush=True)


def warn(msg):
    """Print a warning line, flushed immediately."""
    print(f"[!] {msg}", flush=True)


def die(msg):
    """Print an error to stderr and terminate the process with exit code 1."""
    print(f"[x] {msg}", file=sys.stderr, flush=True)
    sys.exit(1)


# Canonical quantization mode names accepted throughout this script.
SUPPORTED_QUANT = {
    "none",
    "dynamic_int8",
    "float16",
    "int8",
    "int4",
}


def normalize_quantize(q: str) -> str:
    """Normalize a user-supplied --quantize value to a canonical mode name.

    Empty/None input falls back to 'dynamic_int8'; common aliases (fp16, q4,
    i8, ...) are mapped to canonical names; anything unrecognized dies.
    """
    q = (q or "dynamic_int8").strip().lower()
    aliases = {
        "fp32": "none",
        "no": "none",
        "off": "none",
        "fp16": "float16",
        "f16": "float16",
        "i8": "int8",
        "q8": "int8",
        "i4": "int4",
        "q4": "int4",
    }
    q = aliases.get(q, q)
    if q not in SUPPORTED_QUANT:
        die(f"Unsupported --quantize '{q}'. Supported: {sorted(SUPPORTED_QUANT)}")
    return q


def load_config(model_dir: Path):
    """Load config.json from a model directory.

    Returns (full_config, text_config) where text_config is the nested
    'text_config' dict for multimodal checkpoints (empty dict otherwise).
    """
    cfg_path = model_dir / "config.json"
    if not cfg_path.exists():
        die(f"Missing config: {cfg_path}")
    cfg = json.loads(cfg_path.read_text())
    text_cfg = cfg.get("text_config", {})
    return cfg, text_cfg


def inspect_arch(model_dir: Path):
    """Summarize model_type / architecture / vocab_size from config.json."""
    cfg, text_cfg = load_config(model_dir)
    info = {
        "model_type": cfg.get("model_type", "unknown"),
        "architecture": (cfg.get("architectures") or ["unknown"])[0],
        # Multimodal configs keep vocab_size under text_config.
        "vocab_size": cfg.get("vocab_size", text_cfg.get("vocab_size", 262144)),
    }
    log(f"ARCH: {info}")
    return info


def ensure_model_downloaded(model_id: str, model_dir: Path, hf_token: str):
    """Download the HF checkpoint into model_dir unless it is already present.

    Presence is detected via config.json; an empty hf_token means anonymous
    access (public repos only).
    """
    if (model_dir / "config.json").exists():
        log(f"Using existing model dir: {model_dir}")
        return
    log(f"Downloading {model_id} -> {model_dir}")
    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=model_id,
            local_dir=str(model_dir),
            token=hf_token if hf_token else None,
            # NOTE: deprecated in recent huggingface_hub, kept for older versions.
            local_dir_use_symlinks=False,
        )
    except Exception as e:
        die(f"Model download failed: {e}")


def try_builders(mod, builder_names, model_dir: Path):
    """Try each named builder function in `mod` until one yields a model.

    A usable result is any object exposing .eval() (tuples/lists are unwrapped
    to their first element). Returns the model in eval mode, or None.
    """
    for fn_name in builder_names:
        fn = getattr(mod, fn_name, None)
        if fn is None:
            continue
        log(f"Trying {mod.__name__}.{fn_name} ...")
        try:
            m = fn(str(model_dir))
            if m is None:
                warn(" returned None")
                continue
            if isinstance(m, (tuple, list)) and len(m) > 0:
                m = m[0]
            if not hasattr(m, "eval"):
                warn(f" unsupported return type: {type(m)}")
                continue
            m.eval()
            log(f" success with {fn_name}")
            return m
        except Exception as e:
            warn(f" failed: {e}")
    return None


def build_translategemma_4b(checkpoint_path: str):
    """
    Custom builder for TranslateGemma 4B IT (Gemma3 multimodal decoder).
    Strips 'language_model.' prefix from safetensors keys so the standard
    TENSOR_NAMES_SEP_QKV mapping works.

    Architecture (from config.json / verified weight shapes):
    34 layers, embedding_dim=2560, 8 heads, head_dim=256, 4 KV heads,
    intermediate=10240, sliding_window=1024, global every 6th layer.
    """
    import safetensors.torch as st_lib
    import json as json_lib
    from litert_torch.generative.utilities import model_builder, loader as loading_utils
    from litert_torch.generative.examples.gemma3 import decoder as gemma3_decoder
    import litert_torch.generative.layers.model_config as cfg_mod

    norm = cfg_mod.NormalizationConfig(
        type=cfg_mod.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
    )
    ff = cfg_mod.FeedForwardConfig(
        type=cfg_mod.FeedForwardType.GATED,
        activation=cfg_mod.ActivationConfig(cfg_mod.ActivationType.GELU_TANH),
        intermediate_size=10240,
        pre_ff_norm_config=norm,
        post_ff_norm_config=norm,
    )

    def blk(idx):
        # Every 6th layer (1-based) is a global-attention layer with the
        # long-context rotary base; others use local sliding-window attention.
        attn = cfg_mod.AttentionConfig(
            num_heads=8,
            head_dim=256,
            num_query_groups=4,
            rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
            rotary_percentage=1.0,
            qkv_transpose_before_split=True,
            qkv_fused_interleaved=False,
            query_norm_config=norm,
            key_norm_config=norm,
            logit_softcap=None,
            sliding_window_size=1024,
            attn_type=cfg_mod.AttentionType.GLOBAL
            if (idx + 1) % 6 == 0
            else cfg_mod.AttentionType.LOCAL_SLIDING,
        )
        return cfg_mod.TransformerBlockConfig(
            attn_config=attn,
            ff_config=ff,
            pre_attention_norm_config=norm,
            post_attention_norm_config=norm,
        )

    model_config = cfg_mod.ModelConfig(
        vocab_size=262208,
        num_layers=34,
        max_seq_len=8192,
        embedding_dim=2560,
        embedding_scale=2560 ** 0.5,
        block_configs=[blk(i) for i in range(34)],
        final_norm_config=norm,
        lm_head_use_bias=False,
        final_logit_softcap=None,
    )
    tensor_names = loading_utils.ModelLoader.TensorNames(
        ff_up_proj="model.layers.{}.mlp.up_proj",
        ff_down_proj="model.layers.{}.mlp.down_proj",
        ff_gate_proj="model.layers.{}.mlp.gate_proj",
        attn_query_proj="model.layers.{}.self_attn.q_proj",
        attn_key_proj="model.layers.{}.self_attn.k_proj",
        attn_value_proj="model.layers.{}.self_attn.v_proj",
        attn_output_proj="model.layers.{}.self_attn.o_proj",
        attn_query_norm="model.layers.{}.self_attn.q_norm",
        attn_key_norm="model.layers.{}.self_attn.k_norm",
        pre_attn_norm="model.layers.{}.input_layernorm",
        post_attn_norm="model.layers.{}.post_attention_layernorm",
        pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
        post_ff_norm="model.layers.{}.post_feedforward_layernorm",
        embedding="model.embed_tokens",
        final_norm="model.norm",
        lm_head=None,  # weights tied to the embedding
    )

    def custom_loader(path: str):
        # Merge all safetensors shards into one state dict, dropping the
        # 'language_model.' prefix the multimodal checkpoint puts on
        # text-decoder weights.
        idx = json_lib.loads((Path(path) / "model.safetensors.index.json").read_text())
        out = {}
        for fname in set(idx["weight_map"].values()):
            for k, v in st_lib.load_file(str(Path(path) / fname)).items():
                out[k[len("language_model."):] if k.startswith("language_model.") else k] = v
        return out

    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=model_config,
        tensor_names=tensor_names,
        model_class=gemma3_decoder.Decoder,
        custom_loader=custom_loader,
    )


def strategy1_litert_native(model_dir: Path, out_dir: Path, model_type: str,
                            quantize: str, prefill: int, kvcache: int):
    """
    Native LiteRT conversion with explicit quantization support.

    Returns the produced .tflite Path, or None when this strategy does not
    apply (unsupported quantize mode / model type / no working builder).
    """
    log("Strategy 1: litert-torch native")
    from litert_torch.generative.utilities import converter
    from litert_torch.generative.utilities.export_config import ExportConfig
    from litert_torch.generative.layers import kv_cache

    export_config = ExportConfig()
    export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    export_config.mask_as_input = True

    # Map our quantize flags to converter's QuantizationName values
    QUANT_MAP = {
        "none": "none",
        "dynamic_int8": "dynamic_int8",
        "int8": "weight_only_int8",
        "float16": "fp16",
        "int4": "dynamic_int4_block128",
    }
    quant_for_converter = QUANT_MAP.get(quantize)
    if quant_for_converter is None:
        warn(f"No converter mapping for '{quantize}', falling back to Strategy 2")
        return None

    model = None
    if model_type == "gemma3":
        # Try custom 4B builder first (handles TranslateGemma 4B multimodal weight prefix)
        log("Trying custom build_translategemma_4b ...")
        try:
            model = build_translategemma_4b(str(model_dir))
            if model is not None:
                model.eval()
                log(" build_translategemma_4b success")
        except Exception as e:
            warn(f" build_translategemma_4b failed: {e}")
            model = None
        if model is None:
            # Fall back to the stock gemma3 example builders, largest first.
            mod = importlib.import_module("litert_torch.generative.examples.gemma3.gemma3")
            available = [n for n in dir(mod) if n.startswith("build_model")]
            log(f"Gemma3 builders available: {available}")
            preferred = ["build_model_4b", "build_model_2b", "build_model_1b", "build_model_270m", "build_model"]
            ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
            model = try_builders(mod, ordered, model_dir)
    elif model_type in ("gemma", "gemma2"):
        mod = importlib.import_module("litert_torch.generative.examples.gemma2.gemma2")
        available = [n for n in dir(mod) if n.startswith("build_model")]
        log(f"Gemma2 builders available: {available}")
        preferred = ["build_model_4b", "build_model_2b", "build_model"]
        ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
        model = try_builders(mod, ordered, model_dir)
    else:
        warn(f"Model type '{model_type}' not handled by native strategy")
        return None

    if model is None:
        warn("Strategy 1 did not find a compatible builder")
        return None

    converter.convert_to_tflite(
        model,
        output_path=str(out_dir),
        output_name_prefix=f"translategemma-4b-it-{quantize}",
        prefill_seq_len=prefill,
        kv_cache_max_len=kvcache,
        quantize=quant_for_converter,
        export_config=export_config,
    )
    # The converter picks the final filename; prefer one tagged with our
    # quantize mode, otherwise take any .tflite it produced.
    produced = sorted(out_dir.glob(f"*{quantize}*.tflite")) or sorted(out_dir.glob("*.tflite"))
    return produced[0] if produced else None


def strategy2_generic(model_dir: Path, out_dir: Path, prefill: int, quantize: str):
    """
    Generic fallback conversion (logits-only).
    Always exports float32; use strategy3_post_tflite_quantize() afterwards
    for real int4/int8 compression.
    """
    log("Strategy 2: ai_edge_torch generic (wrapped logits-only)")
    import torch
    try:
        import litert_torch as ai_edge_torch
    except Exception:
        import ai_edge_torch  # deprecated fallback
    from transformers import AutoConfig, AutoModelForCausalLM

    dtype = torch.float32
    if quantize == "float16":
        dtype = torch.float16

    cfg = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=True)
    vocab = getattr(cfg, "vocab_size", None)
    if vocab is None and hasattr(cfg, "text_config"):
        # Multimodal configs nest vocab_size under text_config.
        vocab = getattr(cfg.text_config, "vocab_size", None)
    vocab = int(vocab or 262144)

    log(f"Loading HF model on CPU with dtype={dtype} ...")
    base_model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    base_model.eval()

    # TFLite embedding_lookup does not support f16 weights — keep embedding in float32
    if dtype == torch.float16:
        log("Casting embedding and lm_head to float32 for TFLite compatibility")
        if hasattr(base_model, "model") and hasattr(base_model.model, "embed_tokens"):
            base_model.model.embed_tokens = base_model.model.embed_tokens.to(torch.float32)
        if hasattr(base_model, "lm_head"):
            base_model.lm_head = base_model.lm_head.to(torch.float32)

    # KV-cache outputs would break the single-output export signature.
    if hasattr(base_model, "config") and hasattr(base_model.config, "use_cache"):
        base_model.config.use_cache = False

    class LogitsOnlyWrapper(torch.nn.Module):
        """Expose a single-tensor forward (input_ids -> logits) for export."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            out = self.model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
            logits = out[0] if isinstance(out, (tuple, list)) else out.logits
            return logits

    wrapped = LogitsOnlyWrapper(base_model).eval()
    # Trace with a bounded sequence length to keep export time/memory sane.
    sample_ids = torch.randint(0, vocab, (1, min(prefill, 128)), dtype=torch.int64)

    # Always export float32 base; int4/int8 quantization applied post-export via Strategy 3
    out_file = out_dir / "translategemma-4b-it-generic-none.tflite"
    edge_model = ai_edge_torch.convert(wrapped, (sample_ids,))
    edge_model.export(str(out_file))
    return out_file if out_file.exists() else None


def strategy3_post_tflite_quantize(tflite_in: Path, out_dir: Path, quantize: str):
    """
    Post-conversion weight quantization applied directly to a TFLite flatbuffer
    using ai_edge_quantizer (bundled with litert_torch).

    Supported modes and their recipes (per get_supported_layer_schemes()):
      int4         -> INT4 DYNAMIC_RANGE BLOCKWISE_128 (~2 GB for 4B model)
      int8         -> INT8 WEIGHT_ONLY CHANNELWISE (~4 GB)
      dynamic_int8 -> INT8 DYNAMIC_RANGE CHANNELWISE (~4 GB)
      float16     -> FP16 WEIGHT_ONLY FLOAT_CAST (~8 GB)
    """
    log(f"Strategy 3: post-TFLite quantization ({quantize}) on {tflite_in.name}")
    from litert_torch.generative.quantize import quant_attrs as qa, quant_recipe, quant_recipe_utils
    from litert_torch.quantize import translate_recipe

    if quantize == "int4":
        layer = quant_recipe_utils.create_layer_quant_dynamic(qa.Dtype.INT4, qa.Granularity.BLOCKWISE_128)
    elif quantize == "int8":
        layer = quant_recipe_utils.create_layer_quant_weight_only(qa.Dtype.INT8, qa.Granularity.CHANNELWISE)
    elif quantize == "dynamic_int8":
        layer = quant_recipe_utils.create_layer_quant_dynamic(qa.Dtype.INT8, qa.Granularity.CHANNELWISE)
    elif quantize == "float16":
        layer = quant_recipe_utils.create_layer_quant_fp16()
    else:
        warn(f"Strategy 3: no post-TFLite recipe for '{quantize}', skipping")
        return None

    gen_recipe = quant_recipe.GenerativeQuantRecipe(default=layer)
    ai_recipe = translate_recipe.translate_to_ai_edge_recipe(gen_recipe)
    model_bytes = tflite_in.read_bytes()
    log(f" Input: {len(model_bytes) / 1024**3:.2f} GB — quantizing ...")
    quantized_bytes = translate_recipe.quantize_model(model_bytes, ai_recipe)
    out_file = out_dir / f"translategemma-4b-it-{quantize}.tflite"
    out_file.write_bytes(quantized_bytes)
    log(f" Output: {len(quantized_bytes) / 1024**3:.2f} GB -> {out_file}")
    return out_file


def ensure_tokenizer_model(model_dir: Path):
    """Return the path to tokenizer.model, converting tokenizer.json if needed.

    Runs the litert-torch converter tool in a subprocess so its failure cannot
    take down this process silently; dies on any conversion failure.
    """
    tok_model = model_dir / "tokenizer.model"
    if tok_model.exists():
        return tok_model
    tok_json = model_dir / "tokenizer.json"
    if not tok_json.exists():
        die(f"Missing tokenizer files: neither {tok_model} nor {tok_json} exists")
    log("Converting tokenizer.json -> tokenizer.model")
    cmd = [
        sys.executable,
        "-m",
        "litert_torch.generative.tools.tokenizer_to_sentencepiece",
        f"--checkpoint={model_dir}",
        f"--output_path={tok_model}",
    ]
    res = subprocess.run(cmd, text=True, capture_output=True)
    if res.stdout:
        print(res.stdout, flush=True)
    if res.returncode != 0:
        if res.stderr:
            print(res.stderr, file=sys.stderr, flush=True)
        die("Tokenizer conversion failed")
    if not tok_model.exists():
        die("Tokenizer conversion reported success but tokenizer.model is missing")
    return tok_model


def bundle_task(tflite_file: Path, tokenizer_model: Path, task_file: Path):
    """Bundle a TFLite model + tokenizer into a MediaPipe `.task` file.

    BundleConfig's signature varies across mediapipe versions, so the kwargs
    are built adaptively from the parameters it actually accepts.
    """
    log(f"Bundling .task -> {task_file}")
    from mediapipe.tasks.python.genai import bundler

    task_file.parent.mkdir(parents=True, exist_ok=True)
    sig = inspect.signature(bundler.BundleConfig)
    params = sig.parameters
    log(f"BundleConfig params: {list(params.keys())}")
    kwargs = {
        "tflite_model": str(tflite_file),
        "tokenizer_model": str(tokenizer_model),
        "output_filename": str(task_file),
    }
    if "start_token" in params:
        kwargs["start_token"] = ""
    elif "start_tokens" in params:
        kwargs["start_tokens"] = [""]
    if "stop_tokens" in params:
        kwargs["stop_tokens"] = [""]
    if "prompt_prefix" in params:
        kwargs["prompt_prefix"] = ""
    if "prompt_suffix" in params:
        kwargs["prompt_suffix"] = ""
    # Drop anything this mediapipe version does not accept.
    kwargs = {k: v for k, v in kwargs.items() if k in params}
    cfg = bundler.BundleConfig(**kwargs)
    bundler.create_bundle(cfg)


def main():
    """CLI entry point: parse args, run conversion strategies, bundle .task."""
    ap = argparse.ArgumentParser(description="TranslateGemma -> Android .task converter")
    ap.add_argument("--model-id", default="google/translategemma-4b-it")
    ap.add_argument("--model-dir", default="./translategemma-4b-it")
    ap.add_argument("--tflite-dir", default="./tflite_output")
    ap.add_argument("--output-dir", default="./output")
    ap.add_argument("--task-file", default="./output/translategemma-4b-it-android.task")
    ap.add_argument("--quantize", default="dynamic_int8", help=f"One of: {sorted(SUPPORTED_QUANT)}")
    ap.add_argument("--prefill-seq-len", "--prefill", dest="prefill_seq_len", type=int, default=1024)
    ap.add_argument("--kv-cache-max-len", "--kvcache", dest="kv_cache_max_len", type=int, default=1024)
    ap.add_argument("--skip-strategy1", action="store_true")
    ap.add_argument("--bundle-only", action="store_true", help="Skip conversion; only bundle existing TFLite")
    ap.add_argument("--existing-tflite", default="", help="Path to an existing .tflite to bundle")
    ap.add_argument("--allow-no-token", action="store_true", help="Allow model download from public repo without HF_TOKEN")
    args = ap.parse_args()

    q = normalize_quantize(args.quantize)
    hf_token = os.environ.get("HF_TOKEN", "").strip()
    if not hf_token and not args.allow_no_token:
        warn("HF_TOKEN is not set. If model is public, you can pass --allow-no-token.")

    model_dir = Path(args.model_dir)
    tflite_dir = Path(args.tflite_dir)
    output_dir = Path(args.output_dir)
    task_file = Path(args.task_file)
    tflite_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    tflite_file = None
    if args.bundle_only:
        if not args.existing_tflite:
            die("--bundle-only requires --existing-tflite /path/to/model.tflite")
        tflite_file = Path(args.existing_tflite)
        if not tflite_file.exists():
            die(f"Existing tflite not found: {tflite_file}")
        log(f"Bundle-only mode using: {tflite_file}")
        # Apply post-TFLite quantization if requested
        if q != "none":
            try:
                quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
                if quantized:
                    tflite_file = quantized
                    log(f"Strategy 3 success: {tflite_file}")
                else:
                    warn("Strategy 3 skipped; bundling input as-is")
            except Exception as e:
                # Best-effort: bundle the unquantized input rather than abort.
                warn(f"Strategy 3 failed: {e}")
                traceback.print_exc()
    else:
        ensure_model_downloaded(args.model_id, model_dir, hf_token)
        arch = inspect_arch(model_dir)
        model_type = arch["model_type"]

        strategy1_succeeded = False
        if not args.skip_strategy1:
            try:
                tflite_file = strategy1_litert_native(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    model_type=model_type,
                    quantize=q,
                    prefill=args.prefill_seq_len,
                    kvcache=args.kv_cache_max_len,
                )
                if tflite_file:
                    log(f"Strategy 1 success: {tflite_file}")
                    strategy1_succeeded = True
            except Exception as e:
                warn(f"Strategy 1 failed: {e}")
                traceback.print_exc()

        if not tflite_file:
            try:
                tflite_file = strategy2_generic(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    prefill=args.prefill_seq_len,
                    quantize=q,
                )
                if tflite_file:
                    log(f"Strategy 2 success: {tflite_file}")
                    warn("Generic TFLite may not have MediaPipe LLM prefill/decode signatures.")
            except Exception as e:
                warn(f"Strategy 2 failed: {e}")
                traceback.print_exc()

        # Strategy 3: post-TFLite quantization — only when Strategy 2 was used (Strategy 1
        # already applies quantization natively via the converter).
        if tflite_file and not strategy1_succeeded and q != "none":
            try:
                quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
                if quantized:
                    tflite_file = quantized
                    log(f"Strategy 3 success: {tflite_file}")
                else:
                    warn("Strategy 3 skipped; bundling unquantized model")
            except Exception as e:
                warn(f"Strategy 3 failed: {e}")
                traceback.print_exc()

    if not tflite_file:
        die("All conversion strategies failed")

    tokenizer_model = ensure_tokenizer_model(model_dir)
    bundle_task(tflite_file, tokenizer_model, task_file)
    log(f"DONE: {task_file}")
    if task_file.exists():
        log(f"Size: {task_file.stat().st_size / (1024 * 1024):.2f} MB")


if __name__ == "__main__":
    main()