| """ |
| Run text generation inference on the exported Qwen3.5-0.8B TFLite model. |
| |
| Usage: |
| python inference_tflite.py --model_path output/qwen35_0.8b/qwen35_q8_ekv2048.tflite |
| python inference_tflite.py --prompt "Explain gravity" --max_new_tokens 100 |
| """ |
|
|
import argparse
import glob
import logging
import os
import time

import numpy as np
import transformers
from ai_edge_litert import interpreter as tfl_interpreter
|
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Model architecture constants used to size the per-layer KV caches.
# NOTE(review): these values must match the exported TFLite model exactly —
# presumably copied from the export config; verify against the export script
# if the model changes.
# ---------------------------------------------------------------------------

# Total number of transformer layers.
NUM_LAYERS = 24
# Per-layer cache type, repeating pattern of three "linear" layers followed
# by one "full" layer. NOTE(review): the names suggest linear-attention vs
# full-attention layers — confirm against the export config.
LAYER_TYPES = [
    "linear", "linear", "linear", "full",
    "linear", "linear", "linear", "full",
    "linear", "linear", "linear", "full",
    "linear", "linear", "linear", "full",
    "linear", "linear", "linear", "full",
    "linear", "linear", "linear", "full",
]
# Cache shapes for "linear" layers (see create_initial_kv_cache):
#   k: (batch, LINEAR_QKV_DIM, LINEAR_CONV_KERNEL - 1)
#   v: (batch, LINEAR_NUM_HEADS, LINEAR_K_HEAD_DIM, LINEAR_V_HEAD_DIM)
LINEAR_QKV_DIM = 6144
LINEAR_CONV_KERNEL = 4
LINEAR_NUM_HEADS = 16
LINEAR_K_HEAD_DIM = 128
LINEAR_V_HEAD_DIM = 128
# Cache shape for "full" layers (both k and v):
#   (batch, kv_cache_max_len, FULL_ATTN_NUM_KV_HEADS, FULL_ATTN_HEAD_DIM)
FULL_ATTN_NUM_KV_HEADS = 2
FULL_ATTN_HEAD_DIM = 256


# Hugging Face model id — used only to load the tokenizer; weights come from
# the TFLite file.
MODEL_ID = "Qwen/Qwen3.5-0.8B"
|
|
|
|
def create_initial_kv_cache(kv_cache_max_len, batch_size=1):
    """Build zero-initialized KV cache arrays matching the model's per-layer shapes.

    Returns a dict keyed "kv_cache_k_<i>" / "kv_cache_v_<i>" for each layer,
    with shapes chosen by the layer's entry in LAYER_TYPES.
    """
    # Precompute the three distinct shapes once; every layer uses one of them.
    linear_k_shape = (batch_size, LINEAR_QKV_DIM, LINEAR_CONV_KERNEL - 1)
    linear_v_shape = (
        batch_size, LINEAR_NUM_HEADS, LINEAR_K_HEAD_DIM, LINEAR_V_HEAD_DIM
    )
    full_shape = (
        batch_size, kv_cache_max_len, FULL_ATTN_NUM_KV_HEADS, FULL_ATTN_HEAD_DIM
    )

    cache = {}
    for layer_idx, layer_type in enumerate(LAYER_TYPES):
        if layer_type == "linear":
            k_shape, v_shape = linear_k_shape, linear_v_shape
        else:
            # "full" attention layers share one shape for k and v.
            k_shape, v_shape = full_shape, full_shape
        cache[f"kv_cache_k_{layer_idx}"] = np.zeros(k_shape, dtype=np.float32)
        cache[f"kv_cache_v_{layer_idx}"] = np.zeros(v_shape, dtype=np.float32)
    return cache
|
|
|
|
def find_prefill_signature(signatures, seq_len):
    """Pick the smallest prefill signature that can hold seq_len tokens.

    Signature names follow the pattern "prefill_<length>". If every prefill
    signature is shorter than seq_len, fall back to the largest one.

    Returns:
        (signature_name, signature_length) tuple.

    Raises:
        ValueError: if no prefill signatures exist in the model.
    """
    def sig_length(name):
        # "prefill_128" -> 128
        return int(name.split("_")[1])

    candidates = sorted(
        (s for s in signatures if s.startswith("prefill_")), key=sig_length
    )
    if not candidates:
        raise ValueError("No prefill signatures found in model")

    # First signature long enough for the sequence, else the largest one.
    chosen = next(
        (s for s in candidates if sig_length(s) >= seq_len), candidates[-1]
    )
    return chosen, sig_length(chosen)
|
|
|
|
def generate(model_path, prompt, max_new_tokens, kv_cache_max_len):
    """Run greedy (argmax) text generation with the TFLite model.

    Args:
        model_path: Path to the exported .tflite model file.
        prompt: Text prompt to condition generation on.
        max_new_tokens: Upper bound on the number of decode steps.
        kv_cache_max_len: Full-attention KV cache length; must match the
            value the model was exported with.

    Side effects: streams generated text and timing stats to stdout and
    logs progress via `logger`.
    """
    # The tokenizer comes from the original HF checkpoint; the TFLite
    # artifact carries no vocabulary.
    logger.info("Loading tokenizer from: %s", MODEL_ID)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True
    )

    input_ids = tokenizer.encode(prompt)
    logger.info("Prompt: %s", prompt)
    logger.info("Token count: %d", len(input_ids))

    logger.info("Loading TFLite model from: %s", model_path)
    t0 = time.time()
    interp = tfl_interpreter.Interpreter(model_path=model_path)
    interp.allocate_tensors()
    logger.info("Model loaded in %.1fs", time.time() - t0)

    # The exported model exposes several fixed-length prefill signatures
    # plus a single-token "decode" signature.
    signatures = interp.get_signature_list()
    logger.info("Available signatures: %s", list(signatures.keys()))

    kv_cache = create_initial_kv_cache(kv_cache_max_len)

    sig_name, sig_len = find_prefill_signature(signatures, len(input_ids))
    logger.info("Using prefill signature: %s (padding %d -> %d)", sig_name, len(input_ids), sig_len)

    # Pad the prompt with token id 0 up to the signature's fixed length.
    # NOTE(review): if the prompt is longer than the largest prefill
    # signature, (sig_len - len(input_ids)) is negative, the pad list is
    # empty, and `tokens` will not match the signature's expected shape —
    # consider truncating or chunked prefill for long prompts.
    padded_ids = input_ids + [0] * (sig_len - len(input_ids))
    tokens = np.array([padded_ids], dtype=np.int32)
    input_pos = np.arange(sig_len, dtype=np.int32)

    prefill_runner = interp.get_signature_runner(sig_name)
    t0 = time.time()
    prefill_out = prefill_runner(tokens=tokens, input_pos=input_pos, **kv_cache)
    prefill_time = time.time() - t0
    logger.info("Prefill done in %.2fs", prefill_time)

    # Carry the KV cache produced by prefill forward into decoding.
    for key in kv_cache:
        if key in prefill_out:
            kv_cache[key] = prefill_out[key]

    decode_runner = interp.get_signature_runner("decode")
    generated_ids = list(input_ids)
    # NOTE(review): decode positions continue from sig_len, so the zero-pad
    # positions occupy KV slots and the last prompt token is re-fed at a new
    # position on the first decode step — assumes this matches the export's
    # expected prefill/decode convention; confirm against the export script.
    current_pos = sig_len

    logger.info("Starting decode (max %d tokens)...", max_new_tokens)
    print(f"\n--- Generated text ---\n{prompt}", end="", flush=True)

    t0 = time.time()
    for step in range(max_new_tokens):
        # Autoregressive decode: feed the most recent token, one at a time.
        tok = np.array([[generated_ids[-1]]], dtype=np.int32)
        pos = np.array([current_pos], dtype=np.int32)
        decode_out = decode_runner(tokens=tok, input_pos=pos, **kv_cache)

        # Thread the updated cache into the next step.
        for key in kv_cache:
            if key in decode_out:
                kv_cache[key] = decode_out[key]

        # Greedy sampling: argmax over the final position's logits.
        next_token = int(np.argmax(decode_out["logits"][0, -1]))
        generated_ids.append(next_token)
        current_pos += 1

        # Stream each token to stdout as soon as it is produced.
        word = tokenizer.decode([next_token])
        print(word, end="", flush=True)

        # Stop early at end-of-sequence (the EOS token is still printed).
        if next_token == tokenizer.eos_token_id:
            break

    decode_time = time.time() - t0
    num_decoded = len(generated_ids) - len(input_ids)
    print(f"\n\n--- Stats ---")
    print(f"Prefill: {prefill_time:.2f}s ({len(input_ids)} tokens)")
    print(f"Decode: {decode_time:.2f}s ({num_decoded} tokens, {num_decoded/decode_time:.1f} tok/s)")
|
|
|
|
def main():
    """Parse CLI arguments, locate a model if needed, and run generation."""
    parser = argparse.ArgumentParser(description="TFLite inference for Qwen3.5-0.8B")
    parser.add_argument(
        "--model_path",
        default=None,
        help="Path to .tflite model file",
    )
    parser.add_argument(
        "--prompt",
        default="What is the meaning of life?",
        help="Input prompt",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=50,
        help="Maximum tokens to generate",
    )
    parser.add_argument(
        "--kv_cache_max_len",
        type=int,
        default=2048,
        help="KV cache max length (must match exported model)",
    )
    args = parser.parse_args()

    # No explicit path: fall back to the most recently modified exported
    # model under output/. (Fixed: previous code used an inline
    # __import__("os") hack instead of a normal top-level import.)
    if args.model_path is None:
        files = glob.glob("output/**/*.tflite", recursive=True)
        if not files:
            raise FileNotFoundError("No .tflite files found in output/")
        args.model_path = max(files, key=os.path.getmtime)
        logger.info("Auto-found model: %s", args.model_path)

    generate(args.model_path, args.prompt, args.max_new_tokens, args.kv_cache_max_len)
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|