File size: 21,633 Bytes
9c09026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f4c0e9
9c09026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
#!/usr/bin/env python3
import argparse
import importlib
import inspect
import json
import os
import subprocess
import sys
import traceback
from pathlib import Path

# Process-wide environment defaults, applied before any heavy library import.
# setdefault() keeps values already exported by the caller's shell.
#   TRANSFORMERS_NO_TORCHVISION - presumably asks transformers to skip its
#     torchvision-dependent code paths (confirm against the transformers version).
#   PYTHONUNBUFFERED - only read at interpreter start-up, so setting it here
#     affects child Python processes (they inherit os.environ), not this one.
for _env_key, _env_val in (
    ("TRANSFORMERS_NO_TORCHVISION", "1"),
    ("PYTHONUNBUFFERED", "1"),
):
    os.environ.setdefault(_env_key, _env_val)
del _env_key, _env_val


def _say(tag, msg, stream=None):
    # Shared formatter: "<tag> <msg>", flushed at once so progress is visible
    # even when output is piped. stream=None means the current sys.stdout.
    print(f"{tag} {msg}", file=stream, flush=True)


def log(msg):
    """Print a progress line tagged ``[+]`` to stdout."""
    _say("[+]", msg)


def warn(msg):
    """Print a non-fatal warning tagged ``[!]`` to stdout."""
    _say("[!]", msg)


def die(msg):
    """Print a fatal error tagged ``[x]`` to stderr and exit with status 1."""
    _say("[x]", msg, sys.stderr)
    sys.exit(1)


# Canonical --quantize values understood by the conversion strategies.
SUPPORTED_QUANT = {"none", "dynamic_int8", "float16", "int8", "int4"}

# Short-hand spellings accepted on the command line -> canonical name.
_QUANT_ALIASES = {
    "fp32": "none", "no": "none", "off": "none",
    "fp16": "float16", "f16": "float16",
    "i8": "int8", "q8": "int8",
    "i4": "int4", "q4": "int4",
}


def normalize_quantize(q: str) -> str:
    """Return the canonical quantization mode for a user-supplied ``--quantize`` value.

    Empty/None input means ``"dynamic_int8"``. Matching is case-insensitive and
    ignores surrounding whitespace; aliases such as ``fp16`` or ``q4`` are mapped
    to their canonical names. Exits via die() for anything not in SUPPORTED_QUANT.
    """
    raw = q if q else "dynamic_int8"
    key = raw.strip().lower()
    canonical = _QUANT_ALIASES.get(key, key)
    if canonical not in SUPPORTED_QUANT:
        die(f"Unsupported --quantize '{canonical}'. Supported: {sorted(SUPPORTED_QUANT)}")
    return canonical


def load_config(model_dir: Path):
    """Read ``<model_dir>/config.json``.

    Returns:
        ``(cfg, text_cfg)``: the parsed top-level config dict and its
        ``text_config`` sub-dict. ``text_cfg`` is ``{}`` when the key is absent,
        null, or not a JSON object, so callers can always use ``.get()`` on it.

    Exits via die() if the file is missing or is not valid JSON.
    """
    cfg_path = model_dir / "config.json"
    if not cfg_path.exists():
        die(f"Missing config: {cfg_path}")
    try:
        # Hugging Face config files are UTF-8; don't depend on the locale encoding.
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        die(f"Invalid JSON in {cfg_path}: {e}")
    # `"text_config": null` would otherwise propagate None to callers.
    text_cfg = cfg.get("text_config") or {}
    if not isinstance(text_cfg, dict):
        text_cfg = {}
    return cfg, text_cfg


def inspect_arch(model_dir: Path):
    """Summarise the checkpoint's architecture from config.json and log it.

    Returns a dict with ``model_type`` (default ``"unknown"``), ``architecture``
    (first entry of ``architectures``, default ``"unknown"``) and ``vocab_size``
    (top-level value, else ``text_config.vocab_size``, else 262144).
    """
    cfg, text_cfg = load_config(model_dir)
    architectures = cfg.get("architectures") or ["unknown"]
    fallback_vocab = text_cfg.get("vocab_size", 262144)
    info = dict(
        model_type=cfg.get("model_type", "unknown"),
        architecture=architectures[0],
        vocab_size=cfg.get("vocab_size", fallback_vocab),
    )
    log(f"ARCH: {info}")
    return info


def ensure_model_downloaded(model_id: str, model_dir: Path, hf_token: str):
    """Make sure a local snapshot of *model_id* exists in *model_dir*.

    The presence of ``<model_dir>/config.json`` is treated as "already
    downloaded", so a partially completed earlier download is not detected here.
    Otherwise the repo is fetched with ``huggingface_hub.snapshot_download``,
    authenticating with *hf_token* when it is non-empty.

    Exits via die() if the download raises.
    """
    if (model_dir / "config.json").exists():
        log(f"Using existing model dir: {model_dir}")
        return

    log(f"Downloading {model_id} -> {model_dir}")
    try:
        from huggingface_hub import snapshot_download

        kwargs = {
            "repo_id": model_id,
            "local_dir": str(model_dir),
            "token": hf_token or None,
        }
        # local_dir_use_symlinks is deprecated (and ignored) in recent
        # huggingface_hub releases; only pass it when the installed version's
        # signature still declares it, so newer versions cannot reject the call.
        try:
            accepted = inspect.signature(snapshot_download).parameters
        except (TypeError, ValueError):
            accepted = {"local_dir_use_symlinks": None}
        if "local_dir_use_symlinks" in accepted:
            kwargs["local_dir_use_symlinks"] = False
        snapshot_download(**kwargs)
    except Exception as e:
        die(f"Model download failed: {e}")


def try_builders(mod, builder_names, model_dir: Path):
    """Try each named builder in *mod* and return the first usable model.

    Names not present on *mod* are skipped silently. Each builder is called with
    ``str(model_dir)``; a tuple/list result is reduced to its first element. A
    result is usable when it has an ``eval`` method; it is switched to eval mode
    before being returned. Builder exceptions are logged as warnings, not raised.

    Returns the model, or None when no builder succeeded.
    """
    candidates = ((name, getattr(mod, name, None)) for name in builder_names)
    for name, builder in candidates:
        if builder is None:
            continue
        log(f"Trying {mod.__name__}.{name} ...")
        try:
            result = builder(str(model_dir))
            if result is None:
                warn("  returned None")
                continue
            if isinstance(result, (tuple, list)) and result:
                result = result[0]
            if hasattr(result, "eval"):
                result.eval()
                log(f"  success with {name}")
                return result
            warn(f"  unsupported return type: {type(result)}")
        except Exception as exc:
            warn(f"  failed: {exc}")
    return None


def build_translategemma_4b(checkpoint_path: str):
    """
    Custom builder for TranslateGemma 4B IT (Gemma3 multimodal decoder).

    Only the text decoder is built. The custom loader renames checkpoint keys so
    the ``model.*`` TensorNames mapping below matches:

      language_model.<rest>        -> <rest>          (older multimodal layout)
      model.language_model.<rest>  -> model.<rest>    (newer transformers layout)

    Keys matching neither prefix are kept unchanged. Both sharded
    (``model.safetensors.index.json``) and single-file (``model.safetensors``)
    checkpoints are supported.

    Architecture (hard-coded; from config.json / verified weight shapes):
      34 layers, embedding_dim=2560, 8 heads, head_dim=256, 4 KV heads,
      intermediate=10240, sliding_window=1024, global every 6th layer.

    NOTE(review): any rope_scaling entry in config.json is not applied here —
    confirm against the checkpoint's config before relying on long contexts.
    """
    import safetensors.torch as st_lib
    from litert_torch.generative.utilities import model_builder, loader as loading_utils
    from litert_torch.generative.examples.gemma3 import decoder as gemma3_decoder
    import litert_torch.generative.layers.model_config as cfg_mod

    norm = cfg_mod.NormalizationConfig(
        type=cfg_mod.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
    )
    ff = cfg_mod.FeedForwardConfig(
        type=cfg_mod.FeedForwardType.GATED,
        activation=cfg_mod.ActivationConfig(cfg_mod.ActivationType.GELU_TANH),
        intermediate_size=10240,
        pre_ff_norm_config=norm, post_ff_norm_config=norm,
    )

    def blk(idx):
        # Every 6th layer (idx 5, 11, ...) is global attention with a larger RoPE
        # base; the rest use local sliding-window attention.
        attn = cfg_mod.AttentionConfig(
            num_heads=8, head_dim=256, num_query_groups=4,
            rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
            rotary_percentage=1.0, qkv_transpose_before_split=True,
            qkv_fused_interleaved=False,
            query_norm_config=norm, key_norm_config=norm, logit_softcap=None,
            sliding_window_size=1024,
            attn_type=cfg_mod.AttentionType.GLOBAL if (idx + 1) % 6 == 0
            else cfg_mod.AttentionType.LOCAL_SLIDING,
        )
        return cfg_mod.TransformerBlockConfig(
            attn_config=attn, ff_config=ff,
            pre_attention_norm_config=norm, post_attention_norm_config=norm,
        )

    model_config = cfg_mod.ModelConfig(
        vocab_size=262208, num_layers=34, max_seq_len=8192,
        embedding_dim=2560, embedding_scale=2560 ** 0.5,
        block_configs=[blk(i) for i in range(34)],
        final_norm_config=norm, lm_head_use_bias=False, final_logit_softcap=None,
    )

    tensor_names = loading_utils.ModelLoader.TensorNames(
        ff_up_proj="model.layers.{}.mlp.up_proj",
        ff_down_proj="model.layers.{}.mlp.down_proj",
        ff_gate_proj="model.layers.{}.mlp.gate_proj",
        attn_query_proj="model.layers.{}.self_attn.q_proj",
        attn_key_proj="model.layers.{}.self_attn.k_proj",
        attn_value_proj="model.layers.{}.self_attn.v_proj",
        attn_output_proj="model.layers.{}.self_attn.o_proj",
        attn_query_norm="model.layers.{}.self_attn.q_norm",
        attn_key_norm="model.layers.{}.self_attn.k_norm",
        pre_attn_norm="model.layers.{}.input_layernorm",
        post_attn_norm="model.layers.{}.post_attention_layernorm",
        pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
        post_ff_norm="model.layers.{}.post_feedforward_layernorm",
        embedding="model.embed_tokens",
        final_norm="model.norm",
        # No separate head tensor is mapped (lm_head=None).
        lm_head=None,
    )

    def _rename(key: str) -> str:
        # Older layout: "language_model.model.layers.0..." -> "model.layers.0..."
        if key.startswith("language_model."):
            return key[len("language_model."):]
        # Newer layout: "model.language_model.layers.0..." -> "model.layers.0..."
        if key.startswith("model.language_model."):
            return "model." + key[len("model.language_model."):]
        return key

    def custom_loader(path: str):
        root = Path(path)
        index_file = root / "model.safetensors.index.json"
        if index_file.exists():
            weight_map = json.loads(index_file.read_text(encoding="utf-8"))["weight_map"]
            shards = sorted(set(weight_map.values()))
        else:
            # Un-sharded checkpoints ship a single file and no index.
            shards = ["model.safetensors"]
        out = {}
        for fname in shards:
            for k, v in st_lib.load_file(str(root / fname)).items():
                out[_rename(k)] = v
        return out

    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=model_config,
        tensor_names=tensor_names,
        model_class=gemma3_decoder.Decoder,
        custom_loader=custom_loader,
    )


def _ordered_builders(mod, family: str, preferred):
    """Log the ``build_model*`` names on *mod*; return them with *preferred* names first."""
    available = [n for n in dir(mod) if n.startswith("build_model")]
    log(f"{family} builders available: {available}")
    return [n for n in preferred if n in available] + [n for n in available if n not in preferred]


def _tflite_snapshot(out_dir: Path) -> dict:
    """Map each existing ``*.tflite`` in *out_dir* to its mtime in nanoseconds."""
    return {p: p.stat().st_mtime_ns for p in out_dir.glob("*.tflite")}


def _fresh_tflite(out_dir: Path, prefix: str, before: dict):
    """Return the newest ``*.tflite`` created or modified since *before*, or None.

    Files whose name starts with *prefix* are preferred; any other fresh
    ``.tflite`` is used only if none match the prefix.
    """
    fresh = [
        p for p in out_dir.glob("*.tflite")
        if before.get(p) != p.stat().st_mtime_ns
    ]
    pool = [p for p in fresh if p.name.startswith(prefix)] or fresh
    return max(pool, key=lambda p: p.stat().st_mtime_ns) if pool else None


def strategy1_litert_native(model_dir: Path, out_dir: Path, model_type: str, quantize: str, prefill: int, kvcache: int):
    """
    Native LiteRT conversion with explicit quantization support.

    Builds a litert-torch generative model (custom TranslateGemma 4B builder
    first for ``gemma3``, then the library's example builders) and exports it
    with ``converter.convert_to_tflite`` using a transposed KV-cache layout and
    the attention mask as an explicit input.

    Returns:
        The .tflite written by this conversion, or None when the model type or
        quantization mode is unsupported, no builder succeeded, or the
        converter produced no new file. Files already present in *out_dir*
        (earlier runs, Strategy 2/3 outputs) are never returned.
    """
    log("Strategy 1: litert-torch native")

    from litert_torch.generative.utilities import converter
    from litert_torch.generative.utilities.export_config import ExportConfig
    from litert_torch.generative.layers import kv_cache

    export_config = ExportConfig()
    export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    export_config.mask_as_input = True

    # Map our quantize flags to converter's QuantizationName values
    QUANT_MAP = {
        "none": "none",
        "dynamic_int8": "dynamic_int8",
        "int8": "weight_only_int8",
        "float16": "fp16",
        "int4": "dynamic_int4_block128",
    }
    quant_for_converter = QUANT_MAP.get(quantize)
    if quant_for_converter is None:
        warn(f"No converter mapping for '{quantize}', falling back to Strategy 2")
        return None

    model = None

    if model_type == "gemma3":
        # Try custom 4B builder first (handles TranslateGemma 4B multimodal weight prefix)
        log("Trying custom build_translategemma_4b ...")
        try:
            model = build_translategemma_4b(str(model_dir))
            if model is not None:
                model.eval()
                log("  build_translategemma_4b success")
        except Exception as e:
            warn(f"  build_translategemma_4b failed: {e}")
            model = None

        if model is None:
            mod = importlib.import_module("litert_torch.generative.examples.gemma3.gemma3")
            ordered = _ordered_builders(
                mod,
                "Gemma3",
                ["build_model_4b", "build_model_2b", "build_model_1b", "build_model_270m", "build_model"],
            )
            model = try_builders(mod, ordered, model_dir)

    elif model_type in ("gemma", "gemma2"):
        mod = importlib.import_module("litert_torch.generative.examples.gemma2.gemma2")
        ordered = _ordered_builders(mod, "Gemma2", ["build_model_4b", "build_model_2b", "build_model"])
        model = try_builders(mod, ordered, model_dir)

    else:
        warn(f"Model type '{model_type}' not handled by native strategy")
        return None

    if model is None:
        warn("Strategy 1 did not find a compatible builder")
        return None

    prefix = f"translategemma-4b-it-{quantize}"
    # Snapshot first: a name-based glob alone can pick up stale files (e.g.
    # "*int8*" also matches "...dynamic_int8...", "*none*" matches the Strategy 2
    # "...generic-none.tflite").
    before = _tflite_snapshot(out_dir)
    converter.convert_to_tflite(
        model,
        output_path=str(out_dir),
        output_name_prefix=prefix,
        prefill_seq_len=prefill,
        kv_cache_max_len=kvcache,
        quantize=quant_for_converter,
        export_config=export_config,
    )

    produced = _fresh_tflite(out_dir, prefix, before)
    if produced is None:
        warn(f"Strategy 1: converter finished but wrote no new .tflite to {out_dir}")
    return produced


def strategy2_generic(model_dir: Path, out_dir: Path, prefill: int, quantize: str):
    """
    Generic fallback conversion (logits-only).

    Loads the checkpoint with ``AutoModelForCausalLM`` on CPU and converts a
    wrapper whose single input is ``input_ids`` and whose single output is the
    logits tensor (``use_cache=False``, so no KV-cache inputs/outputs).

    Weights are loaded in float32, except when ``quantize == "float16"``: then the
    model is loaded in float16 and the input embedding and ``lm_head`` (when
    found at ``model.model.embed_tokens`` / ``model.lm_head``) are cast back to
    float32. The output is always ``translategemma-4b-it-generic-none.tflite``;
    use strategy3_post_tflite_quantize() afterwards for real int4/int8 compression.

    Returns:
        Path of the exported .tflite, or None if the file was not written.
    """
    log("Strategy 2: ai_edge_torch generic (wrapped logits-only)")

    import torch
    try:
        import litert_torch as ai_edge_torch
    except Exception:
        import ai_edge_torch  # deprecated fallback

    from transformers import AutoConfig, AutoModelForCausalLM

    dtype = torch.float16 if quantize == "float16" else torch.float32

    cfg = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=True)
    vocab = getattr(cfg, "vocab_size", None)
    if vocab is None and hasattr(cfg, "text_config"):
        vocab = getattr(cfg.text_config, "vocab_size", None)
    vocab = int(vocab or 262144)

    log(f"Loading HF model on CPU with dtype={dtype} ...")
    base_model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    base_model.eval()

    # TFLite embedding_lookup does not support f16 weights — keep embedding in float32.
    # NOTE(review): the decoder layers stay float16, so fp32 activations from the
    # embedding may hit dtype mismatches while tracing — confirm on the target model.
    if dtype == torch.float16:
        log("Casting embedding and lm_head to float32 for TFLite compatibility")
        if hasattr(base_model, "model") and hasattr(base_model.model, "embed_tokens"):
            base_model.model.embed_tokens = base_model.model.embed_tokens.to(torch.float32)
        if hasattr(base_model, "lm_head"):
            base_model.lm_head = base_model.lm_head.to(torch.float32)

    if hasattr(base_model, "config") and hasattr(base_model.config, "use_cache"):
        base_model.config.use_cache = False

    class LogitsOnlyWrapper(torch.nn.Module):
        """Expose only the logits so the exported graph has one input and one output."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            out = self.model(
                input_ids=input_ids,
                use_cache=False,
                return_dict=False,
            )
            logits = out[0] if isinstance(out, (tuple, list)) else out.logits
            return logits

    wrapped = LogitsOnlyWrapper(base_model).eval()
    # Trace with at most 128 tokens; never fewer than 1 (prefill <= 0 would give
    # an empty or invalid sample shape).
    sample_len = max(1, min(prefill, 128))
    sample_ids = torch.randint(0, vocab, (1, sample_len), dtype=torch.int64)

    # Always exported under the "-none" name; int4/int8 quantization is applied post-export via Strategy 3
    out_file = out_dir / "translategemma-4b-it-generic-none.tflite"
    edge_model = ai_edge_torch.convert(wrapped, (sample_ids,))
    edge_model.export(str(out_file))

    return out_file if out_file.exists() else None


def strategy3_post_tflite_quantize(tflite_in: Path, out_dir: Path, quantize: str):
    """
    Quantize the weights of an already-exported TFLite flatbuffer.

    A single default layer recipe is chosen for *quantize*, translated into an
    ai_edge_quantizer recipe via litert_torch, applied to the bytes of
    *tflite_in*, and written to ``<out_dir>/translategemma-4b-it-<quantize>.tflite``.

    Recipes (approximate size for a 4B-parameter model):
      int4         -> INT4 dynamic-range, blockwise-128   (~2 GB)
      int8         -> INT8 weight-only, channelwise       (~4 GB)
      dynamic_int8 -> INT8 dynamic-range, channelwise     (~4 GB)
      float16      -> FP16 layer recipe                   (~8 GB)

    Returns the written path, or None (after a warning) for any other mode.
    The whole input model is held in memory while quantizing.
    """
    log(f"Strategy 3: post-TFLite quantization ({quantize}) on {tflite_in.name}")

    from litert_torch.generative.quantize import quant_attrs as qa, quant_recipe, quant_recipe_utils
    from litert_torch.quantize import translate_recipe

    # Lambdas defer every recipe lookup until the selected mode is actually built.
    layer_factories = {
        "int4": lambda: quant_recipe_utils.create_layer_quant_dynamic(qa.Dtype.INT4, qa.Granularity.BLOCKWISE_128),
        "int8": lambda: quant_recipe_utils.create_layer_quant_weight_only(qa.Dtype.INT8, qa.Granularity.CHANNELWISE),
        "dynamic_int8": lambda: quant_recipe_utils.create_layer_quant_dynamic(qa.Dtype.INT8, qa.Granularity.CHANNELWISE),
        "float16": lambda: quant_recipe_utils.create_layer_quant_fp16(),
    }
    make_layer = layer_factories.get(quantize)
    if make_layer is None:
        warn(f"Strategy 3: no post-TFLite recipe for '{quantize}', skipping")
        return None

    ai_recipe = translate_recipe.translate_to_ai_edge_recipe(
        quant_recipe.GenerativeQuantRecipe(default=make_layer())
    )

    gib = 1024**3
    source_bytes = tflite_in.read_bytes()
    log(f"  Input: {len(source_bytes) / gib:.2f} GB — quantizing ...")
    result_bytes = translate_recipe.quantize_model(source_bytes, ai_recipe)

    target = out_dir / f"translategemma-4b-it-{quantize}.tflite"
    target.write_bytes(result_bytes)
    log(f"  Output: {len(result_bytes) / gib:.2f} GB -> {target}")
    return target


def ensure_tokenizer_model(model_dir: Path):
    """Return ``<model_dir>/tokenizer.model``, creating it from tokenizer.json if needed.

    The conversion runs ``litert_torch.generative.tools.tokenizer_to_sentencepiece``
    in a child Python process. Its stdout is echoed; its stderr is echoed only
    when the child fails. Exits via die() when neither tokenizer file exists,
    when the conversion fails, or when the output is missing afterwards.
    """
    sp_model = model_dir / "tokenizer.model"
    if sp_model.exists():
        return sp_model

    hf_json = model_dir / "tokenizer.json"
    if not hf_json.exists():
        die(f"Missing tokenizer files: neither {sp_model} nor {hf_json} exists")

    log("Converting tokenizer.json -> tokenizer.model")
    tool = "litert_torch.generative.tools.tokenizer_to_sentencepiece"
    proc = subprocess.run(
        [sys.executable, "-m", tool, f"--checkpoint={model_dir}", f"--output_path={sp_model}"],
        capture_output=True,
        text=True,
    )
    if proc.stdout:
        print(proc.stdout, flush=True)

    failed = proc.returncode != 0
    if failed and proc.stderr:
        print(proc.stderr, file=sys.stderr, flush=True)
    if failed:
        die("Tokenizer conversion failed")

    if not sp_model.exists():
        die("Tokenizer conversion reported success but tokenizer.model is missing")
    return sp_model


def bundle_task(tflite_file: Path, tokenizer_model: Path, task_file: Path):
    """Bundle a TFLite model and SentencePiece tokenizer into a MediaPipe ``.task``.

    The MediaPipe ``BundleConfig`` signature is inspected at runtime so that
    differing bundler versions work: only keyword arguments it declares are
    passed, and any that are dropped are reported with a warning.

    Token settings: start token ``<bos>``; stop tokens ``<eos>`` and
    ``<end_of_turn>``; empty prompt prefix/suffix, so the calling app is expected
    to apply the chat template itself.
    """
    log(f"Bundling .task -> {task_file}")
    from mediapipe.tasks.python.genai import bundler

    task_file.parent.mkdir(parents=True, exist_ok=True)

    params = inspect.signature(bundler.BundleConfig).parameters
    log(f"BundleConfig params: {list(params.keys())}")

    kwargs = {
        "tflite_model": str(tflite_file),
        "tokenizer_model": str(tokenizer_model),
        "output_filename": str(task_file),
    }

    if "start_token" in params:
        kwargs["start_token"] = "<bos>"
    elif "start_tokens" in params:
        kwargs["start_tokens"] = ["<bos>"]

    if "stop_tokens" in params:
        # Instruction-tuned Gemma models close their turn with <end_of_turn>;
        # stopping only on <eos> lets generation run on past the answer.
        kwargs["stop_tokens"] = ["<eos>", "<end_of_turn>"]

    if "prompt_prefix" in params:
        kwargs["prompt_prefix"] = ""
    if "prompt_suffix" in params:
        kwargs["prompt_suffix"] = ""

    dropped = sorted(k for k in kwargs if k not in params)
    if dropped:
        warn(f"BundleConfig does not accept {dropped}; omitting them")
    kwargs = {k: v for k, v in kwargs.items() if k in params}
    cfg = bundler.BundleConfig(**kwargs)
    bundler.create_bundle(cfg)


def _post_quantize(tflite_file: Path, tflite_dir: Path, q: str, skip_msg: str) -> Path:
    """Run Strategy 3 on *tflite_file* and return the file that should be bundled.

    Returns the quantized file on success. If Strategy 3 has no recipe it logs
    *skip_msg*; if it raises it logs the error and traceback. In both cases
    *tflite_file* is returned unchanged.
    """
    try:
        quantized = strategy3_post_tflite_quantize(tflite_file, tflite_dir, q)
    except Exception as e:
        # Best-effort: a failed post-quantization still leaves a usable model.
        warn(f"Strategy 3 failed: {e}")
        traceback.print_exc()
        return tflite_file
    if not quantized:
        warn(skip_msg)
        return tflite_file
    log(f"Strategy 3 success: {quantized}")
    return quantized


def main():
    """Command-line entry point: download, convert, optionally quantize, and bundle.

    Conversion tries Strategy 1 (native litert-torch, quantized by the
    converter), then Strategy 2 (generic logits-only export). Strategy 3
    (post-TFLite quantization) runs only after Strategy 2, or in --bundle-only
    mode when --quantize is not ``none``. The result is bundled with the model's
    tokenizer into a .task file, which defaults to
    ``<output-dir>/translategemma-4b-it-android.task``.
    """
    ap = argparse.ArgumentParser(description="TranslateGemma -> Android .task converter")

    ap.add_argument("--model-id", default="google/translategemma-4b-it")
    ap.add_argument("--model-dir", default="./translategemma-4b-it")
    ap.add_argument("--tflite-dir", default="./tflite_output")
    ap.add_argument("--output-dir", default="./output")
    ap.add_argument(
        "--task-file",
        default="",
        help="Output .task path (default: <output-dir>/translategemma-4b-it-android.task)",
    )

    ap.add_argument("--quantize", default="dynamic_int8", help=f"One of: {sorted(SUPPORTED_QUANT)}")

    ap.add_argument("--prefill-seq-len", "--prefill", dest="prefill_seq_len", type=int, default=1024)
    ap.add_argument("--kv-cache-max-len", "--kvcache", dest="kv_cache_max_len", type=int, default=1024)

    ap.add_argument("--skip-strategy1", action="store_true")
    ap.add_argument("--bundle-only", action="store_true", help="Skip conversion; only bundle existing TFLite")
    ap.add_argument("--existing-tflite", default="", help="Path to an existing .tflite to bundle")
    ap.add_argument("--allow-no-token", action="store_true", help="Allow model download from public repo without HF_TOKEN")

    args = ap.parse_args()
    q = normalize_quantize(args.quantize)

    hf_token = os.environ.get("HF_TOKEN", "").strip()
    if not hf_token and not args.allow_no_token:
        warn("HF_TOKEN is not set. If model is public, you can pass --allow-no-token.")

    model_dir = Path(args.model_dir)
    tflite_dir = Path(args.tflite_dir)
    output_dir = Path(args.output_dir)
    # Previously --task-file defaulted to a fixed ./output/... path, so
    # --output-dir had no effect; derive the default from --output-dir instead.
    task_file = (
        Path(args.task_file) if args.task_file
        else output_dir / "translategemma-4b-it-android.task"
    )

    tflite_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    tflite_file = None

    if args.bundle_only:
        if not args.existing_tflite:
            die("--bundle-only requires --existing-tflite /path/to/model.tflite")
        tflite_file = Path(args.existing_tflite)
        if not tflite_file.exists():
            die(f"Existing tflite not found: {tflite_file}")
        log(f"Bundle-only mode using: {tflite_file}")
        # Apply post-TFLite quantization if requested. --quantize defaults to
        # dynamic_int8, so an already-quantized input would be quantized again.
        if q != "none":
            warn(f"Applying Strategy 3 ({q}) to the existing model; pass --quantize none if it is already quantized")
            tflite_file = _post_quantize(
                tflite_file, tflite_dir, q, "Strategy 3 skipped; bundling input as-is"
            )
    else:
        ensure_model_downloaded(args.model_id, model_dir, hf_token)
        arch = inspect_arch(model_dir)
        model_type = arch["model_type"]

        strategy1_succeeded = False
        if not args.skip_strategy1:
            try:
                tflite_file = strategy1_litert_native(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    model_type=model_type,
                    quantize=q,
                    prefill=args.prefill_seq_len,
                    kvcache=args.kv_cache_max_len,
                )
                if tflite_file:
                    log(f"Strategy 1 success: {tflite_file}")
                    strategy1_succeeded = True
            except Exception as e:
                warn(f"Strategy 1 failed: {e}")
                traceback.print_exc()

        if not tflite_file:
            try:
                tflite_file = strategy2_generic(
                    model_dir=model_dir,
                    out_dir=tflite_dir,
                    prefill=args.prefill_seq_len,
                    quantize=q,
                )
                if tflite_file:
                    log(f"Strategy 2 success: {tflite_file}")
                    warn("Generic TFLite may not have MediaPipe LLM prefill/decode signatures.")
            except Exception as e:
                warn(f"Strategy 2 failed: {e}")
                traceback.print_exc()

        # Strategy 3: post-TFLite quantization — only when Strategy 2 was used (Strategy 1
        # already applies quantization natively via the converter).
        if tflite_file and not strategy1_succeeded and q != "none":
            tflite_file = _post_quantize(
                tflite_file, tflite_dir, q, "Strategy 3 skipped; bundling unquantized model"
            )

        if not tflite_file:
            die("All conversion strategies failed")

    tokenizer_model = ensure_tokenizer_model(model_dir)
    bundle_task(tflite_file, tokenizer_model, task_file)

    log(f"DONE: {task_file}")
    if task_file.exists():
        log(f"Size: {task_file.stat().st_size / (1024 * 1024):.2f} MB")

# Script entry point; main() returns None, so the process exits with status 0
# unless die() exits earlier with status 1.
if __name__ == "__main__":
    raise SystemExit(main())