# translategemma-4b-it-android-task-quantized / scripts / convert_translategemma_android.py
# barakplasma's picture
# Upload scripts/convert_translategemma_android.py with huggingface_hub
# 7f4c0e9 verified
#!/usr/bin/env python3
import argparse
import importlib
import inspect
import json
import os
import subprocess
import sys
import traceback
from pathlib import Path
# Process-wide environment defaults. setdefault() leaves any value the caller
# already exported untouched. These run at import time, before the ML imports
# that the functions below perform lazily.
# NOTE(review): TRANSFORMERS_NO_TORCHVISION presumably stops transformers from
# importing torchvision — confirm against the installed transformers version.
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
# NOTE(review): PYTHONUNBUFFERED is normally read at interpreter start-up, so
# this likely only affects child Python processes spawned later — confirm intent.
os.environ.setdefault("PYTHONUNBUFFERED", "1")
def log(msg):
    """Print an informational message to stdout, prefixed with ``[+]``."""
    line = f"[+] {msg}"
    print(line, flush=True)
def warn(msg):
    """Print a warning message to stdout, prefixed with ``[!]``."""
    line = f"[!] {msg}"
    print(line, flush=True)
def die(msg):
    """Write a fatal ``[x]`` message to stderr and exit with status 1.

    Raises:
        SystemExit: always, with exit code 1.
    """
    fatal_line = f"[x] {msg}"
    print(fatal_line, file=sys.stderr, flush=True)
    raise SystemExit(1)
# Canonical quantization mode names; aliases are resolved to one of these by
# normalize_quantize() before membership is checked.
SUPPORTED_QUANT = {"none", "dynamic_int8", "float16", "int8", "int4"}
def normalize_quantize(q: str) -> str:
    """Resolve a user-supplied ``--quantize`` value to its canonical name.

    The value is stripped and lower-cased, then short aliases (``fp16``,
    ``q4``, ...) are mapped to canonical names. An empty or ``None`` value
    means ``dynamic_int8``. A name that is not in ``SUPPORTED_QUANT`` after
    alias resolution is reported through ``die()``.
    """
    alias_to_canonical = {
        "fp32": "none",
        "no": "none",
        "off": "none",
        "fp16": "float16",
        "f16": "float16",
        "i8": "int8",
        "q8": "int8",
        "i4": "int4",
        "q4": "int4",
    }
    requested = (q or "dynamic_int8").strip().lower()
    canonical = alias_to_canonical.get(requested, requested)
    if canonical in SUPPORTED_QUANT:
        return canonical
    die(f"Unsupported --quantize '{canonical}'. Supported: {sorted(SUPPORTED_QUANT)}")
    return canonical
def load_config(model_dir: Path):
    """Load ``config.json`` from *model_dir*.

    Args:
        model_dir: Directory expected to contain ``config.json``.

    Returns:
        A ``(cfg, text_cfg)`` tuple: the full parsed config dict and its
        ``text_config`` sub-dict. ``text_cfg`` is always a dict; it is empty
        when the key is missing or explicitly ``null``.

    Exits the process through ``die()`` when the file does not exist.
    """
    cfg_path = model_dir / "config.json"
    if not cfg_path.exists():
        die(f"Missing config: {cfg_path}")
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    # `"text_config": null` would otherwise hand None to callers that call
    # .get() on the result.
    text_cfg = cfg.get("text_config") or {}
    return cfg, text_cfg
def inspect_arch(model_dir: Path):
    """Summarise the model architecture described by ``config.json``.

    Returns a dict with ``model_type``, the first entry of ``architectures``
    and ``vocab_size``. The vocab size is read from the top level, then from
    ``text_config``, and finally defaults to 262144. The summary is logged
    before being returned.
    """
    cfg, text_cfg = load_config(model_dir)
    architectures = cfg.get("architectures") or ["unknown"]
    nested_vocab = text_cfg.get("vocab_size", 262144)
    info = dict(
        model_type=cfg.get("model_type", "unknown"),
        architecture=architectures[0],
        vocab_size=cfg.get("vocab_size", nested_vocab),
    )
    log(f"ARCH: {info}")
    return info
def ensure_model_downloaded(model_id: str, model_dir: Path, hf_token: str):
    """Make sure a local snapshot of *model_id* exists in *model_dir*.

    A directory that already contains ``config.json`` is treated as a
    complete download and reused. Otherwise the repository is fetched with
    ``huggingface_hub.snapshot_download``; any failure (including a missing
    ``huggingface_hub`` package) ends the process through ``die()``.

    Args:
        model_id: Hugging Face repository id.
        model_dir: Local target directory.
        hf_token: Access token; an empty string means anonymous access.
    """
    config_marker = model_dir / "config.json"
    if config_marker.exists():
        log(f"Using existing model dir: {model_dir}")
        return

    log(f"Downloading {model_id} -> {model_dir}")
    try:
        from huggingface_hub import snapshot_download

        download_kwargs = {
            "repo_id": model_id,
            "local_dir": str(model_dir),
            "token": hf_token or None,
            "local_dir_use_symlinks": False,
        }
        snapshot_download(**download_kwargs)
    except Exception as err:
        die(f"Model download failed: {err}")
def try_builders(mod, builder_names, model_dir: Path):
    """Return the first usable model produced by a builder function of *mod*.

    Each name in *builder_names* is looked up on *mod*; missing names are
    skipped. A builder is called with ``str(model_dir)``. If it returns a
    non-empty tuple or list, the first element is taken as the model. A result
    counts as usable when it exposes ``eval``; it is switched to eval mode and
    returned. Exceptions from a builder are logged and the next one is tried.

    Returns:
        The built model, or ``None`` when no builder succeeded.
    """
    candidates = ((name, getattr(mod, name, None)) for name in builder_names)
    for fn_name, builder in candidates:
        if builder is None:
            continue
        log(f"Trying {mod.__name__}.{fn_name} ...")
        try:
            built = builder(str(model_dir))
            if built is None:
                warn(" returned None")
                continue
            is_sequence = isinstance(built, (tuple, list))
            if is_sequence and len(built) > 0:
                built = built[0]
            if not hasattr(built, "eval"):
                warn(f" unsupported return type: {type(built)}")
                continue
            built.eval()
            log(f" success with {fn_name}")
            return built
        except Exception as err:
            warn(f" failed: {err}")
    return None
def build_translategemma_4b(checkpoint_path: str):
    """Build the litert_torch Gemma3 decoder for the TranslateGemma 4B IT checkpoint.

    The checkpoint keys carry a ``language_model.`` prefix; the custom loader
    strips it so the ``model.layers.{}...`` tensor names declared below match.

    Hard-coded architecture used by this builder:
        34 layers, embedding_dim=2560, 8 heads, head_dim=256, 4 query groups,
        intermediate_size=10240, sliding_window_size=1024, vocab_size=262208,
        max_seq_len=8192. Every 6th layer (1-based) uses global attention with
        rotary base 1_000_000; all others use local sliding attention with
        rotary base 10_000.

    Args:
        checkpoint_path: Directory holding the safetensors shard(s) and,
            for sharded checkpoints, ``model.safetensors.index.json``.

    Returns:
        Whatever ``model_builder.build_decoder_only_model`` returns for the
        ``gemma3_decoder.Decoder`` model class.

    Raises:
        FileNotFoundError: from the loader when no safetensors file is found.
    """
    import safetensors.torch as st_lib
    from litert_torch.generative.utilities import model_builder, loader as loading_utils
    from litert_torch.generative.examples.gemma3 import decoder as gemma3_decoder
    import litert_torch.generative.layers.model_config as cfg_mod

    num_layers = 34
    embedding_dim = 2560

    norm = cfg_mod.NormalizationConfig(
        type=cfg_mod.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
    )
    ff = cfg_mod.FeedForwardConfig(
        type=cfg_mod.FeedForwardType.GATED,
        activation=cfg_mod.ActivationConfig(cfg_mod.ActivationType.GELU_TANH),
        intermediate_size=10240,
        pre_ff_norm_config=norm,
        post_ff_norm_config=norm,
    )

    def blk(idx):
        # Layers 6, 12, 18, ... (1-based) are the global-attention layers.
        is_global = (idx + 1) % 6 == 0
        attn = cfg_mod.AttentionConfig(
            num_heads=8,
            head_dim=256,
            num_query_groups=4,
            rotary_base=1_000_000 if is_global else 10_000,
            rotary_percentage=1.0,
            qkv_transpose_before_split=True,
            qkv_fused_interleaved=False,
            query_norm_config=norm,
            key_norm_config=norm,
            logit_softcap=None,
            sliding_window_size=1024,
            attn_type=(
                cfg_mod.AttentionType.GLOBAL
                if is_global
                else cfg_mod.AttentionType.LOCAL_SLIDING
            ),
        )
        return cfg_mod.TransformerBlockConfig(
            attn_config=attn,
            ff_config=ff,
            pre_attention_norm_config=norm,
            post_attention_norm_config=norm,
        )

    model_config = cfg_mod.ModelConfig(
        vocab_size=262208,
        num_layers=num_layers,
        max_seq_len=8192,
        embedding_dim=embedding_dim,
        embedding_scale=embedding_dim ** 0.5,
        block_configs=[blk(i) for i in range(num_layers)],
        final_norm_config=norm,
        lm_head_use_bias=False,
        final_logit_softcap=None,
    )

    tensor_names = loading_utils.ModelLoader.TensorNames(
        ff_up_proj="model.layers.{}.mlp.up_proj",
        ff_down_proj="model.layers.{}.mlp.down_proj",
        ff_gate_proj="model.layers.{}.mlp.gate_proj",
        attn_query_proj="model.layers.{}.self_attn.q_proj",
        attn_key_proj="model.layers.{}.self_attn.k_proj",
        attn_value_proj="model.layers.{}.self_attn.v_proj",
        attn_output_proj="model.layers.{}.self_attn.o_proj",
        attn_query_norm="model.layers.{}.self_attn.q_norm",
        attn_key_norm="model.layers.{}.self_attn.k_norm",
        pre_attn_norm="model.layers.{}.input_layernorm",
        post_attn_norm="model.layers.{}.post_attention_layernorm",
        pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
        post_ff_norm="model.layers.{}.post_feedforward_layernorm",
        embedding="model.embed_tokens",
        final_norm="model.norm",
        # NOTE(review): no separate lm_head tensor is mapped — presumably the
        # head is tied to the embedding; confirm against the loader.
        lm_head=None,
    )

    def custom_loader(path: str):
        """Load all safetensors shards and drop the ``language_model.`` key prefix."""
        ckpt_dir = Path(path)
        index_file = ckpt_dir / "model.safetensors.index.json"
        if index_file.exists():
            index = json.loads(index_file.read_text(encoding="utf-8"))
            shard_names = sorted(set(index["weight_map"].values()))
        else:
            # Single-file checkpoints ship without an index file.
            shard_names = sorted(p.name for p in ckpt_dir.glob("*.safetensors"))
        if not shard_names:
            raise FileNotFoundError(f"No safetensors shards found in {ckpt_dir}")

        prefix = "language_model."
        out = {}
        for fname in shard_names:
            for k, v in st_lib.load_file(str(ckpt_dir / fname)).items():
                out[k[len(prefix):] if k.startswith(prefix) else k] = v
        return out

    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=model_config,
        tensor_names=tensor_names,
        model_class=gemma3_decoder.Decoder,
        custom_loader=custom_loader,
    )
def _try_example_builders(module_name: str, family: str, preferred, model_dir: Path):
    """Import a litert_torch example module and try its ``build_model*`` functions.

    Builders listed in *preferred* are tried first (in that order), followed
    by any other ``build_model*`` attribute the module exposes.
    """
    mod = importlib.import_module(module_name)
    available = [n for n in dir(mod) if n.startswith("build_model")]
    log(f"{family} builders available: {available}")
    ordered = [n for n in preferred if n in available] + [n for n in available if n not in preferred]
    return try_builders(mod, ordered, model_dir)


def _tflite_mtimes(out_dir: Path):
    """Map every ``*.tflite`` file currently in *out_dir* to its mtime (ns)."""
    return {p: p.stat().st_mtime_ns for p in out_dir.glob("*.tflite")}


def _pick_fresh_tflite(out_dir: Path, before, name_prefix: str):
    """Return a ``.tflite`` created or rewritten since the *before* snapshot.

    Files whose name starts with *name_prefix* win over other fresh files;
    ties are resolved by sorted path order. Returns ``None`` when nothing in
    *out_dir* is new or modified.
    """
    fresh = sorted(p for p, mtime in _tflite_mtimes(out_dir).items() if before.get(p) != mtime)
    prefixed = [p for p in fresh if p.name.startswith(name_prefix)]
    candidates = prefixed or fresh
    return candidates[0] if candidates else None


def strategy1_litert_native(model_dir: Path, out_dir: Path, model_type: str, quantize: str, prefill: int, kvcache: int):
    """Convert the checkpoint with litert_torch's generative converter.

    Quantization is applied by the converter itself, using the mapping from
    this script's ``--quantize`` names to the converter's names below.

    Args:
        model_dir: Local checkpoint directory.
        out_dir: Directory the converter writes ``.tflite`` files into.
        model_type: ``model_type`` from ``config.json``; ``gemma3``,
            ``gemma2`` and ``gemma`` are handled.
        quantize: Canonical quantization name.
        prefill: Prefill sequence length passed to the converter.
        kvcache: Maximum KV-cache length passed to the converter.

    Returns:
        Path of the ``.tflite`` written by this run, or ``None`` when the
        model type or quantization mode is not handled, no builder produced a
        model, or the converter wrote no new file. Import and conversion
        errors propagate to the caller.
    """
    log("Strategy 1: litert-torch native")
    from litert_torch.generative.utilities import converter
    from litert_torch.generative.utilities.export_config import ExportConfig
    from litert_torch.generative.layers import kv_cache

    export_config = ExportConfig()
    export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    export_config.mask_as_input = True

    # Map our quantize flags to converter's QuantizationName values
    quant_map = {
        "none": "none",
        "dynamic_int8": "dynamic_int8",
        "int8": "weight_only_int8",
        "float16": "fp16",
        "int4": "dynamic_int4_block128",
    }
    quant_for_converter = quant_map.get(quantize)
    if quant_for_converter is None:
        warn(f"No converter mapping for '{quantize}', falling back to Strategy 2")
        return None

    model = None
    if model_type == "gemma3":
        # Try custom 4B builder first (handles TranslateGemma 4B multimodal weight prefix)
        log("Trying custom build_translategemma_4b ...")
        try:
            model = build_translategemma_4b(str(model_dir))
            if model is not None:
                model.eval()
                log(" build_translategemma_4b success")
        except Exception as e:
            warn(f" build_translategemma_4b failed: {e}")
            model = None
        if model is None:
            model = _try_example_builders(
                "litert_torch.generative.examples.gemma3.gemma3",
                "Gemma3",
                ["build_model_4b", "build_model_2b", "build_model_1b", "build_model_270m", "build_model"],
                model_dir,
            )
    elif model_type in ("gemma", "gemma2"):
        model = _try_example_builders(
            "litert_torch.generative.examples.gemma2.gemma2",
            "Gemma2",
            ["build_model_4b", "build_model_2b", "build_model"],
            model_dir,
        )
    else:
        warn(f"Model type '{model_type}' not handled by native strategy")
        return None

    if model is None:
        warn("Strategy 1 did not find a compatible builder")
        return None

    name_prefix = f"translategemma-4b-it-{quantize}"
    # Snapshot existing outputs so a leftover file from an earlier run (for
    # example a dynamic_int8 model when int8 is requested) is never mistaken
    # for this run's result.
    existing = _tflite_mtimes(out_dir)
    converter.convert_to_tflite(
        model,
        output_path=str(out_dir),
        output_name_prefix=name_prefix,
        prefill_seq_len=prefill,
        kv_cache_max_len=kvcache,
        quantize=quant_for_converter,
        export_config=export_config,
    )
    produced = _pick_fresh_tflite(out_dir, existing, name_prefix)
    if produced is None:
        warn("Strategy 1 converter finished but wrote no new .tflite file")
    return produced
def strategy2_generic(model_dir: Path, out_dir: Path, prefill: int, quantize: str):
    """Generic fallback: export a logits-only wrapper of the Hugging Face model.

    The model is loaded with ``AutoModelForCausalLM`` and wrapped so that the
    exported graph takes ``input_ids`` and returns logits only. Weights are
    loaded as float32, except when *quantize* is ``"float16"``: then they are
    loaded as float16 with the embedding and ``lm_head`` cast back to float32.
    No integer quantization happens here; the output file is always named
    ``translategemma-4b-it-generic-none.tflite`` and is meant to be compressed
    afterwards by ``strategy3_post_tflite_quantize()``.

    Args:
        model_dir: Local checkpoint directory.
        out_dir: Directory the ``.tflite`` is written to.
        prefill: Requested prefill length; the sample input uses at most 128 tokens.
        quantize: Canonical quantization name (only ``"float16"`` changes behaviour).

    Returns:
        The exported file path, or ``None`` if the file does not exist afterwards.
    """
    log("Strategy 2: ai_edge_torch generic (wrapped logits-only)")
    import torch
    try:
        import litert_torch as ai_edge_torch
    except Exception:
        import ai_edge_torch  # deprecated fallback
    from transformers import AutoConfig, AutoModelForCausalLM

    dtype = torch.float16 if quantize == "float16" else torch.float32

    cfg = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=True)
    vocab = getattr(cfg, "vocab_size", None)
    if vocab is None:
        # Multimodal configs keep the vocabulary size on the nested text config.
        vocab = getattr(getattr(cfg, "text_config", None), "vocab_size", None)
    vocab = int(vocab or 262144)

    log(f"Loading HF model on CPU with dtype={dtype} ...")
    base_model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    base_model.eval()

    # TFLite embedding_lookup does not support f16 weights — keep embedding in float32
    if dtype == torch.float16:
        log("Casting embedding and lm_head to float32 for TFLite compatibility")
        inner = getattr(base_model, "model", None)
        if inner is not None and hasattr(inner, "embed_tokens"):
            inner.embed_tokens = inner.embed_tokens.to(torch.float32)
        if hasattr(base_model, "lm_head"):
            base_model.lm_head = base_model.lm_head.to(torch.float32)

    hf_config = getattr(base_model, "config", None)
    if hf_config is not None and hasattr(hf_config, "use_cache"):
        hf_config.use_cache = False

    class LogitsOnlyWrapper(torch.nn.Module):
        """Reduce the HF forward pass to ``input_ids -> logits``."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            out = self.model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
            if isinstance(out, (tuple, list)):
                return out[0]
            return out.logits

    wrapped = LogitsOnlyWrapper(base_model).eval()
    sample_len = min(prefill, 128)
    sample_ids = torch.randint(0, vocab, (1, sample_len), dtype=torch.int64)

    # No integer quantization here; int4/int8 is applied post-export via Strategy 3
    out_file = out_dir / "translategemma-4b-it-generic-none.tflite"
    edge_model = ai_edge_torch.convert(wrapped, (sample_ids,))
    edge_model.export(str(out_file))
    if out_file.exists():
        return out_file
    return None
def strategy3_post_tflite_quantize(tflite_in: Path, out_dir: Path, quantize: str):
    """Quantize the weights of an existing TFLite flatbuffer.

    Recipes per mode, as built below:
        int4         -> dynamic, INT4, BLOCKWISE_128
        int8         -> weight-only, INT8, CHANNELWISE
        dynamic_int8 -> dynamic, INT8, CHANNELWISE
        float16      -> FP16

    The whole input file is read into memory, quantized, and written to
    ``<out_dir>/translategemma-4b-it-<quantize>.tflite``.

    Returns:
        The output path, or ``None`` when *quantize* has no recipe.
    """
    log(f"Strategy 3: post-TFLite quantization ({quantize}) on {tflite_in.name}")
    from litert_torch.generative.quantize import quant_attrs as qa, quant_recipe, quant_recipe_utils
    from litert_torch.quantize import translate_recipe

    # Factories are lazy so only the requested recipe is ever constructed.
    layer_factories = {
        "int4": lambda: quant_recipe_utils.create_layer_quant_dynamic(
            qa.Dtype.INT4, qa.Granularity.BLOCKWISE_128
        ),
        "int8": lambda: quant_recipe_utils.create_layer_quant_weight_only(
            qa.Dtype.INT8, qa.Granularity.CHANNELWISE
        ),
        "dynamic_int8": lambda: quant_recipe_utils.create_layer_quant_dynamic(
            qa.Dtype.INT8, qa.Granularity.CHANNELWISE
        ),
        "float16": lambda: quant_recipe_utils.create_layer_quant_fp16(),
    }
    make_layer = layer_factories.get(quantize)
    if make_layer is None:
        warn(f"Strategy 3: no post-TFLite recipe for '{quantize}', skipping")
        return None
    layer = make_layer()

    gen_recipe = quant_recipe.GenerativeQuantRecipe(default=layer)
    ai_recipe = translate_recipe.translate_to_ai_edge_recipe(gen_recipe)

    gib = 1024**3
    model_bytes = tflite_in.read_bytes()
    log(f" Input: {len(model_bytes) / gib:.2f} GB — quantizing ...")
    quantized_bytes = translate_recipe.quantize_model(model_bytes, ai_recipe)

    out_file = out_dir / f"translategemma-4b-it-{quantize}.tflite"
    out_file.write_bytes(quantized_bytes)
    log(f" Output: {len(quantized_bytes) / gib:.2f} GB -> {out_file}")
    return out_file
def ensure_tokenizer_model(model_dir: Path):
    """Return the path to ``tokenizer.model``, creating it if necessary.

    An existing ``tokenizer.model`` is returned as-is. Otherwise
    ``tokenizer.json`` is converted by running the
    ``litert_torch.generative.tools.tokenizer_to_sentencepiece`` module in a
    child interpreter. A missing ``tokenizer.json``, a non-zero exit status,
    or a missing output file ends the process through ``die()``.
    """
    sp_model = model_dir / "tokenizer.model"
    if sp_model.exists():
        return sp_model

    hf_tokenizer = model_dir / "tokenizer.json"
    if not hf_tokenizer.exists():
        die(f"Missing tokenizer files: neither {sp_model} nor {hf_tokenizer} exists")

    log("Converting tokenizer.json -> tokenizer.model")
    completed = subprocess.run(
        [
            sys.executable,
            "-m",
            "litert_torch.generative.tools.tokenizer_to_sentencepiece",
            f"--checkpoint={model_dir}",
            f"--output_path={sp_model}",
        ],
        text=True,
        capture_output=True,
    )
    # Relay the child's stdout regardless of its exit status.
    if completed.stdout:
        print(completed.stdout, flush=True)
    conversion_failed = completed.returncode != 0
    if conversion_failed:
        if completed.stderr:
            print(completed.stderr, file=sys.stderr, flush=True)
        die("Tokenizer conversion failed")
    if not sp_model.exists():
        die("Tokenizer conversion reported success but tokenizer.model is missing")
    return sp_model
def bundle_task(tflite_file: Path, tokenizer_model: Path, task_file: Path):
    """Bundle a TFLite model and tokenizer into a MediaPipe ``.task`` file.

    ``BundleConfig``'s signature is inspected at runtime and only the keyword
    arguments it accepts are passed. Optional settings used when available:
    ``start_token="<bos>"`` (or ``start_tokens=["<bos>"]`` if only that
    spelling exists), ``stop_tokens=["<eos>"]`` and empty
    ``prompt_prefix`` / ``prompt_suffix``. The parent directory of
    *task_file* is created if needed.
    """
    log(f"Bundling .task -> {task_file}")
    from mediapipe.tasks.python.genai import bundler

    task_file.parent.mkdir(parents=True, exist_ok=True)
    params = inspect.signature(bundler.BundleConfig).parameters
    log(f"BundleConfig params: {list(params.keys())}")

    wanted = {
        "tflite_model": str(tflite_file),
        "tokenizer_model": str(tokenizer_model),
        "output_filename": str(task_file),
        "stop_tokens": ["<eos>"],
        "prompt_prefix": "",
        "prompt_suffix": "",
    }
    # The singular spelling wins; the plural one is used only as a fallback.
    start_spellings = (("start_token", "<bos>"), ("start_tokens", ["<bos>"]))
    for key, value in start_spellings:
        if key in params:
            wanted[key] = value
            break

    accepted = {key: value for key, value in wanted.items() if key in params}
    bundle_config = bundler.BundleConfig(**accepted)
    bundler.create_bundle(bundle_config)
def _build_arg_parser():
    """Create the command-line parser for the converter."""
    ap = argparse.ArgumentParser(description="TranslateGemma -> Android .task converter")
    ap.add_argument("--model-id", default="google/translategemma-4b-it")
    ap.add_argument("--model-dir", default="./translategemma-4b-it")
    ap.add_argument("--tflite-dir", default="./tflite_output")
    ap.add_argument("--output-dir", default="./output")
    ap.add_argument(
        "--task-file",
        default=None,
        help="Output .task path (default: <output-dir>/translategemma-4b-it-android.task)",
    )
    ap.add_argument("--quantize", default="dynamic_int8", help=f"One of: {sorted(SUPPORTED_QUANT)}")
    ap.add_argument("--prefill-seq-len", "--prefill", dest="prefill_seq_len", type=int, default=1024)
    ap.add_argument("--kv-cache-max-len", "--kvcache", dest="kv_cache_max_len", type=int, default=1024)
    ap.add_argument("--skip-strategy1", action="store_true")
    ap.add_argument("--bundle-only", action="store_true", help="Skip conversion; only bundle existing TFLite")
    ap.add_argument("--existing-tflite", default="", help="Path to an existing .tflite to bundle")
    ap.add_argument("--allow-no-token", action="store_true", help="Allow model download from public repo without HF_TOKEN")
    return ap


def _post_quantize_or_keep(tflite_file: Path, tflite_dir: Path, q: str, skipped_msg: str):
    """Run Strategy 3 and return the quantized file, or *tflite_file* on skip/failure.

    Failures are logged with a traceback and never abort the run, so the
    caller can still bundle the unquantized model.
    """
    try:
        quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
    except Exception as e:
        warn(f"Strategy 3 failed: {e}")
        traceback.print_exc()
        return tflite_file
    if quantized:
        log(f"Strategy 3 success: {quantized}")
        return quantized
    warn(skipped_msg)
    return tflite_file


def _convert_model(args, q: str, model_dir: Path, tflite_dir: Path, hf_token: str):
    """Download the model if needed and convert it, trying Strategy 1 then 2 (+3).

    Returns the resulting ``.tflite`` path, or ``None`` when every strategy failed.
    """
    ensure_model_downloaded(args.model_id, model_dir, hf_token)
    model_type = inspect_arch(model_dir)["model_type"]

    tflite_file = None
    strategy1_succeeded = False
    if not args.skip_strategy1:
        try:
            tflite_file = strategy1_litert_native(
                model_dir=model_dir,
                out_dir=tflite_dir,
                model_type=model_type,
                quantize=q,
                prefill=args.prefill_seq_len,
                kvcache=args.kv_cache_max_len,
            )
            if tflite_file:
                log(f"Strategy 1 success: {tflite_file}")
                strategy1_succeeded = True
        except Exception as e:
            warn(f"Strategy 1 failed: {e}")
            traceback.print_exc()

    if not tflite_file:
        try:
            tflite_file = strategy2_generic(
                model_dir=model_dir,
                out_dir=tflite_dir,
                prefill=args.prefill_seq_len,
                quantize=q,
            )
            if tflite_file:
                log(f"Strategy 2 success: {tflite_file}")
                warn("Generic TFLite may not have MediaPipe LLM prefill/decode signatures.")
        except Exception as e:
            warn(f"Strategy 2 failed: {e}")
            traceback.print_exc()

    # Strategy 3: post-TFLite quantization — only when Strategy 2 was used (Strategy 1
    # already applies quantization natively via the converter).
    if tflite_file and not strategy1_succeeded and q != "none":
        tflite_file = _post_quantize_or_keep(
            tflite_file, tflite_dir, q, "Strategy 3 skipped; bundling unquantized model"
        )
    return tflite_file


def main():
    """CLI entry point: convert (or reuse) a TFLite model and bundle it into a ``.task``.

    Flow:
        1. Parse arguments and normalise ``--quantize``.
        2. ``--bundle-only``: take ``--existing-tflite`` and, unless the mode
           is ``none``, post-quantize it with Strategy 3.
           Otherwise: download the model if needed, then try Strategy 1
           (native converter) and fall back to Strategy 2 + Strategy 3.
        3. Ensure ``tokenizer.model`` exists and bundle everything.

    Fatal conditions terminate the process through ``die()``.
    """
    args = _build_arg_parser().parse_args()
    q = normalize_quantize(args.quantize)

    hf_token = os.environ.get("HF_TOKEN", "").strip()
    if not hf_token and not args.allow_no_token:
        warn("HF_TOKEN is not set. If model is public, you can pass --allow-no-token.")

    model_dir = Path(args.model_dir)
    tflite_dir = Path(args.tflite_dir)
    output_dir = Path(args.output_dir)
    # Without an explicit --task-file the bundle goes into --output-dir;
    # previously that option was created on disk but otherwise ignored.
    if args.task_file:
        task_file = Path(args.task_file)
    else:
        task_file = output_dir / "translategemma-4b-it-android.task"
    tflite_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.bundle_only:
        if not args.existing_tflite:
            die("--bundle-only requires --existing-tflite /path/to/model.tflite")
        tflite_file = Path(args.existing_tflite)
        if not tflite_file.exists():
            die(f"Existing tflite not found: {tflite_file}")
        log(f"Bundle-only mode using: {tflite_file}")
        # Apply post-TFLite quantization if requested
        if q != "none":
            tflite_file = _post_quantize_or_keep(
                tflite_file, tflite_dir, q, "Strategy 3 skipped; bundling input as-is"
            )
    else:
        tflite_file = _convert_model(args, q, model_dir, tflite_dir, hf_token)

    if not tflite_file:
        die("All conversion strategies failed")

    tokenizer_model = ensure_tokenizer_model(model_dir)
    bundle_task(tflite_file, tokenizer_model, task_file)
    log(f"DONE: {task_file}")
    if task_file.exists():
        log(f"Size: {task_file.stat().st_size / (1024 * 1024):.2f} MB")
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()