| |
import argparse
import importlib
import json
import os
import shlex
import signal
import subprocess
import sys
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
|
|
| |
| SUPPORTED_BY_CONVERTER = {"none", "dynamic_int8", "float16", "int8", "int4"} |
|
|
|
|
| def run( |
| cmd, |
| cwd=None, |
| check=True, |
| env=None, |
| log_file: Optional[Path] = None, |
| timeout_sec: int = 0, |
| ): |
| print(f"[CMD] {' '.join(shlex.quote(c) for c in cmd)}", flush=True) |
|
|
| p = subprocess.Popen( |
| cmd, |
| cwd=cwd, |
| env=env, |
| stdout=subprocess.PIPE, |
| stderr=subprocess.STDOUT, |
| text=True, |
| bufsize=1, |
| preexec_fn=os.setsid, |
| ) |
|
|
| lines = [] |
| timed_out = False |
|
|
| try: |
| assert p.stdout is not None |
| for line in p.stdout: |
| print(line, end="", flush=True) |
| lines.append(line) |
|
|
| if timeout_sec > 0: |
| rc = p.wait(timeout=timeout_sec) |
| else: |
| rc = p.wait() |
|
|
| except subprocess.TimeoutExpired: |
| timed_out = True |
| print(f"\n[!] Timeout after {timeout_sec}s. Terminating process group...", flush=True) |
| try: |
| os.killpg(os.getpgid(p.pid), signal.SIGTERM) |
| except Exception: |
| pass |
| try: |
| p.wait(timeout=8) |
| except Exception: |
| try: |
| os.killpg(os.getpgid(p.pid), signal.SIGKILL) |
| except Exception: |
| pass |
| rc = 124 |
|
|
| except KeyboardInterrupt: |
| print("\n[!] KeyboardInterrupt received. Terminating process group...", flush=True) |
| try: |
| os.killpg(os.getpgid(p.pid), signal.SIGTERM) |
| except Exception: |
| pass |
| try: |
| p.wait(timeout=8) |
| except Exception: |
| try: |
| os.killpg(os.getpgid(p.pid), signal.SIGKILL) |
| except Exception: |
| pass |
| raise |
|
|
| out = "".join(lines) |
| if log_file: |
| log_file.parent.mkdir(parents=True, exist_ok=True) |
| log_file.write_text(out, encoding="utf-8") |
|
|
| if check and rc != 0: |
| msg = f"Command failed ({rc})" |
| if timed_out: |
| msg += " [timeout]" |
| msg += f": {' '.join(cmd)}" |
| raise RuntimeError(msg) |
|
|
| return rc, out, timed_out |
|
|
|
|
| def ensure_hf_repo(repo_id: str, private: bool): |
| cmd = ["hf", "repo", "create", repo_id, "--type", "model"] |
| if private: |
| cmd.append("--private") |
| else: |
| cmd.append("--public") |
|
|
| rc, out, _ = run(cmd, check=False) |
| low = out.lower() |
| if rc != 0 and "already exists" not in low: |
| raise RuntimeError(f"Failed creating repo {repo_id}") |
|
|
|
|
| def hf_upload(repo_id: str, local_path: Path, path_in_repo: str): |
| cmd = [ |
| "hf", |
| "upload", |
| repo_id, |
| str(local_path), |
| path_in_repo, |
| "--repo-type", |
| "model", |
| ] |
| run(cmd, check=True) |
|
|
|
|
| def normalize_quant(q: str) -> str: |
| q = q.strip().lower() |
| aliases = { |
| "fp16": "float16", |
| "f16": "float16", |
| "i8": "int8", |
| "i4": "int4", |
| "q8": "int8", |
| "q4": "int4", |
| "fp32": "none", |
| "off": "none", |
| "no": "none", |
| } |
| return aliases.get(q, q) |
|
|
|
|
| def detect_native_quant_possible(model_dir: Path): |
| """ |
| Heuristic for your environment: |
| - For gemma3, native quant is likely usable only if build_model_4b exists. |
| """ |
| cfg = model_dir / "config.json" |
| if not cfg.exists(): |
| return False, "config.json missing" |
|
|
| try: |
| j = json.loads(cfg.read_text()) |
| model_type = (j.get("model_type") or "").lower() |
| except Exception as e: |
| return False, f"config parse failed: {e}" |
|
|
| if model_type == "gemma3": |
| try: |
| mod = importlib.import_module("litert_torch.generative.examples.gemma3.gemma3") |
| available = {n for n in dir(mod) if n.startswith("build_model")} |
| if "build_model_4b" in available: |
| return True, "gemma3 build_model_4b present" |
| return False, f"gemma3 4b builder missing; available={sorted(available)}" |
| except Exception as e: |
| return False, f"cannot import gemma3 converter module: {e}" |
|
|
| return True, f"model_type={model_type or 'unknown'}" |
|
|
|
|
| def plan_quants(requested_quants, native_ok: bool): |
| """ |
| Build plan for requested quantization modes. |
| |
| Strategy 3 (post-TFLite quantization) gives int4/int8/dynamic_int8 genuinely |
| different output sizes, so they are each built independently. |
| |
| The only deduplication is: if 'none' is requested alongside int4/int8/dynamic_int8, |
| the float32 TFLite produced by those builds also satisfies 'none', so we skip a |
| redundant bare 'none' build. |
| """ |
| requested = [] |
| unsupported = [] |
|
|
| for q in requested_quants: |
| if q not in SUPPORTED_BY_CONVERTER: |
| unsupported.append(q) |
| else: |
| requested.append(q) |
|
|
| |
| build_plan = requested[:] |
| alias_map = {q: q for q in requested} |
| mode = "native_quant_available" if native_ok else "strategy3_post_tflite" |
|
|
| return build_plan, alias_map, unsupported, mode |
|
|
|
|
| def main(): |
| ap = argparse.ArgumentParser( |
| description="Run multi-quant conversion+bundle and upload successful artifacts to HF." |
| ) |
| ap.add_argument("--converter-script", default="/home/ubuntu/convert_translategemma_android.py") |
| ap.add_argument("--model-id", default="google/translategemma-4b-it") |
| ap.add_argument("--model-dir", default="/home/ubuntu/translategemma-4b-it") |
| ap.add_argument("--tflite-root", default="/home/ubuntu/tflite_output") |
| ap.add_argument("--output-dir", default="/home/ubuntu/output") |
| ap.add_argument("--log-dir", default="/home/ubuntu/logs") |
| ap.add_argument("--prefill", type=int, default=1024) |
| ap.add_argument("--kvcache", type=int, default=1024) |
| ap.add_argument("--timeout-sec", type=int, default=0, help="Per-quant timeout. 0 = no timeout.") |
|
|
| ap.add_argument( |
| "--quants", |
| default="int4,int8,fp8,fp16,dynamic_int8", |
| help="Comma-separated requested modes", |
| ) |
| ap.add_argument( |
| "--repo-id", |
| default="barakplasma/translategemma-4b-it-android-task-quantized", |
| ) |
|
|
| vis = ap.add_mutually_exclusive_group() |
| vis.add_argument("--private", action="store_true", default=True) |
| vis.add_argument("--public", action="store_true") |
|
|
| ap.add_argument("--no-upload", action="store_true") |
| args = ap.parse_args() |
|
|
| converter_script = Path(args.converter_script) |
| if not converter_script.exists(): |
| print(f"[x] converter script not found: {converter_script}", file=sys.stderr) |
| sys.exit(1) |
|
|
| model_dir = Path(args.model_dir) |
| tflite_root = Path(args.tflite_root) |
| output_dir = Path(args.output_dir) |
| log_dir = Path(args.log_dir) |
|
|
| tflite_root.mkdir(parents=True, exist_ok=True) |
| output_dir.mkdir(parents=True, exist_ok=True) |
| log_dir.mkdir(parents=True, exist_ok=True) |
|
|
| |
| quant_list = [normalize_quant(x) for x in args.quants.split(",") if x.strip()] |
| seen = set() |
| quant_list = [q for q in quant_list if not (q in seen or seen.add(q))] |
|
|
| native_ok, native_reason = detect_native_quant_possible(model_dir) |
| print(f"[+] native quant capability: {native_ok} ({native_reason})") |
|
|
| build_plan, alias_map, unsupported, plan_mode = plan_quants(quant_list, native_ok=native_ok) |
| print(f"[+] plan mode: {plan_mode}") |
| print(f"[+] requested quants: {quant_list}") |
| print(f"[+] build plan: {build_plan}") |
| if unsupported: |
| print(f"[!] unsupported (skipped): {unsupported}") |
|
|
| built = {} |
| results = [] |
|
|
| try: |
| for q in build_plan: |
| print(f"\n=== BUILD QUANT: {q} ===", flush=True) |
| q_tflite_dir = tflite_root / q |
| q_tflite_dir.mkdir(parents=True, exist_ok=True) |
|
|
| task_file = output_dir / f"translategemma-4b-it-{q}.task" |
| log_file = log_dir / f"convert_{q}.log" |
|
|
| |
| base_tflite = tflite_root / "none" / "translategemma-4b-it-generic-none.tflite" |
| if q not in ("none", "float16") and base_tflite.exists(): |
| print(f"[+] Reusing existing float32 TFLite for {q} (Strategy 3 will quantize)", flush=True) |
| cmd = [ |
| sys.executable, |
| str(converter_script), |
| "--bundle-only", |
| "--existing-tflite", |
| str(base_tflite), |
| "--tflite-dir", |
| str(q_tflite_dir), |
| "--output-dir", |
| str(output_dir), |
| "--task-file", |
| str(task_file), |
| "--quantize", |
| q, |
| "--model-dir", |
| str(model_dir), |
| ] |
| else: |
| cmd = [ |
| sys.executable, |
| str(converter_script), |
| "--model-id", |
| args.model_id, |
| "--model-dir", |
| str(model_dir), |
| "--tflite-dir", |
| str(q_tflite_dir), |
| "--output-dir", |
| str(output_dir), |
| "--task-file", |
| str(task_file), |
| "--quantize", |
| q, |
| "--prefill", |
| str(args.prefill), |
| "--kvcache", |
| str(args.kvcache), |
| "--allow-no-token", |
| ] |
|
|
| rc, _, timed_out = run( |
| cmd, |
| check=False, |
| log_file=log_file, |
| timeout_sec=args.timeout_sec, |
| ) |
|
|
| tflites = sorted(q_tflite_dir.glob("*.tflite")) |
| tflite_file = tflites[-1] if tflites else None |
| ok = (rc == 0) and task_file.exists() |
|
|
| built[q] = { |
| "quant": q, |
| "ok": ok, |
| "rc": rc, |
| "timed_out": timed_out, |
| "task": str(task_file) if task_file.exists() else "", |
| "task_size": task_file.stat().st_size if task_file.exists() else 0, |
| "tflite": str(tflite_file) if tflite_file and tflite_file.exists() else "", |
| "log": str(log_file), |
| } |
|
|
| except KeyboardInterrupt: |
| print("\n[!] Stopped by user (Ctrl+C). Partial results will be saved.", flush=True) |
|
|
| |
| for q in quant_list: |
| if q in unsupported: |
| results.append( |
| { |
| "quant": q, |
| "ok": False, |
| "rc": 2, |
| "timed_out": False, |
| "task": "", |
| "task_size": 0, |
| "tflite": "", |
| "log": "", |
| "status": "unsupported", |
| "alias_of": "", |
| } |
| ) |
| continue |
|
|
| bq = alias_map.get(q, q) |
| b = built.get(bq) |
| if not b: |
| results.append( |
| { |
| "quant": q, |
| "ok": False, |
| "rc": 130, |
| "timed_out": False, |
| "task": "", |
| "task_size": 0, |
| "tflite": "", |
| "log": "", |
| "status": "not_built", |
| "alias_of": bq, |
| } |
| ) |
| continue |
|
|
| status = "built" if q == bq else f"aliased_to_{bq}" |
| results.append( |
| { |
| **b, |
| "quant": q, |
| "status": status, |
| "alias_of": bq if q != bq else "", |
| } |
| ) |
|
|
| |
| summary = { |
| "timestamp_utc": datetime.now(timezone.utc).isoformat(), |
| "native_quant_capability": native_ok, |
| "native_quant_reason": native_reason, |
| "plan_mode": plan_mode, |
| "requested_quants": quant_list, |
| "build_plan": build_plan, |
| "unsupported_quants": unsupported, |
| "results": results, |
| } |
|
|
| summary_json = output_dir / "quantization_summary.json" |
| summary_json.write_text(json.dumps(summary, indent=2), encoding="utf-8") |
| print(f"\n[+] Wrote summary: {summary_json}") |
|
|
| |
| readme = output_dir / "README.md" |
| lines = [] |
| lines.append("---") |
| lines.append("license: other") |
| lines.append("library_name: mediapipe") |
| lines.append("pipeline_tag: text-generation") |
| lines.append("---\n") |
| lines.append("# TranslateGemma 4B IT - Quantized Android Task Bundles\n") |
| lines.append(f"Generated: `{datetime.now(timezone.utc).isoformat()}`\n") |
| lines.append(f"- Native quant capability: `{native_ok}`") |
| lines.append(f"- Reason: `{native_reason}`") |
| lines.append(f"- Plan mode: `{plan_mode}`\n") |
| lines.append("| Requested quant | Status | Built from | Task file | Size (bytes) |") |
| lines.append("|---|---|---|---|---|") |
| for r in results: |
| if r.get("status") == "unsupported": |
| status = "⏭️ unsupported by converter" |
| built_from = "-" |
| elif str(r.get("status", "")).startswith("aliased_to_"): |
| status = "↪️ aliased" |
| built_from = f"`{r.get('alias_of','-')}`" |
| else: |
| if r.get("timed_out"): |
| status = "⏱️ timeout" |
| else: |
| status = "✅ success" if r.get("ok") else f"❌ failed (rc={r.get('rc')})" |
| built_from = "`self`" |
|
|
| task_name = Path(r["task"]).name if r.get("task") else "-" |
| lines.append( |
| f"| `{r['quant']}` | {status} | {built_from} | `{task_name}` | `{r.get('task_size',0)}` |" |
| ) |
|
|
| lines.append("\n## Notes") |
| lines.append("- Aliased entries are not rebuilt; they point to an equivalent built variant.") |
| lines.append("- `fp8` is often unsupported in current converter/runtime stacks.") |
| lines.append("- Verify on-device compatibility before public release.") |
|
|
| readme.write_text("\n".join(lines), encoding="utf-8") |
| print(f"[+] Wrote README: {readme}") |
|
|
| if args.no_upload: |
| print("[!] --no-upload set. Done.") |
| return |
|
|
| private = False if args.public else True |
| ensure_hf_repo(args.repo_id, private=private) |
|
|
| hf_upload(args.repo_id, readme, "README.md") |
| hf_upload(args.repo_id, summary_json, "quantization_summary.json") |
|
|
| uploaded = set() |
|
|
| |
| for q, b in built.items(): |
| if b.get("log"): |
| lp = Path(b["log"]) |
| if lp.exists() and str(lp) not in uploaded: |
| hf_upload(args.repo_id, lp, f"logs/{lp.name}") |
| uploaded.add(str(lp)) |
|
|
| if b.get("tflite"): |
| tp = Path(b["tflite"]) |
| if tp.exists() and str(tp) not in uploaded: |
| hf_upload(args.repo_id, tp, f"artifacts/{q}/{tp.name}") |
| uploaded.add(str(tp)) |
|
|
| if b.get("task"): |
| tk = Path(b["task"]) |
| if tk.exists() and str(tk) not in uploaded: |
| hf_upload(args.repo_id, tk, f"artifacts/{q}/{tk.name}") |
| uploaded.add(str(tk)) |
|
|
| print(f"\n[+] Upload complete: https://huggingface.co/{args.repo_id}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |