# inference_tflite.py — Qwen3.5-0.8B LiteRT
# Duplicated from g-ntovas/Qwen3.5-0.8B-LiteRT (revision 766ec4a), by GabrieleConte.
"""
Run text generation inference on the exported Qwen3.5-0.8B TFLite model.
Usage:
python inference_tflite.py --model_path output/qwen35_0.8b/qwen35_q8_ekv2048.tflite
python inference_tflite.py --prompt "Explain gravity" --max_new_tokens 100
"""
import argparse
import glob
import logging
import os
import time

import numpy as np
import transformers
from ai_edge_litert import interpreter as tfl_interpreter
# Module-wide logging: timestamped, INFO-level messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)
# Architecture constants (must match qwen35_model.py)
NUM_LAYERS = 24
# Repeating block of three linear-attention layers followed by one
# full-attention layer, six times over (24 layers total).
LAYER_TYPES = ["linear", "linear", "linear", "full"] * 6
# Linear-attention layer cache dimensions.
LINEAR_QKV_DIM = 6144
LINEAR_CONV_KERNEL = 4
LINEAR_NUM_HEADS = 16
LINEAR_K_HEAD_DIM = 128
LINEAR_V_HEAD_DIM = 128
# Full-attention layer cache dimensions.
FULL_ATTN_NUM_KV_HEADS = 2
FULL_ATTN_HEAD_DIM = 256
MODEL_ID = "Qwen/Qwen3.5-0.8B"


def create_initial_kv_cache(kv_cache_max_len, batch_size=1):
    """Create zero-initialized KV cache arrays matching the model's per-layer shapes.

    Linear-attention layers carry a small fixed-size convolution/state cache,
    while full-attention layers carry a sequence-length cache of size
    ``kv_cache_max_len``.

    Args:
        kv_cache_max_len: Cache capacity for full-attention layers.
        batch_size: Leading batch dimension of every cache array.

    Returns:
        Dict mapping ``kv_cache_k_{i}`` / ``kv_cache_v_{i}`` to float32 zeros.
    """
    cache = {}
    for layer_idx, layer_type in enumerate(LAYER_TYPES):
        if layer_type == "linear":
            k_shape = (batch_size, LINEAR_QKV_DIM, LINEAR_CONV_KERNEL - 1)
            v_shape = (batch_size, LINEAR_NUM_HEADS, LINEAR_K_HEAD_DIM, LINEAR_V_HEAD_DIM)
        else:
            k_shape = (batch_size, kv_cache_max_len, FULL_ATTN_NUM_KV_HEADS, FULL_ATTN_HEAD_DIM)
            v_shape = k_shape
        cache[f"kv_cache_k_{layer_idx}"] = np.zeros(k_shape, dtype=np.float32)
        cache[f"kv_cache_v_{layer_idx}"] = np.zeros(v_shape, dtype=np.float32)
    return cache
def find_prefill_signature(signatures, seq_len):
    """Find the best prefill signature for the given sequence length.

    Picks the smallest ``prefill_<len>`` signature whose length is at least
    ``seq_len``; if none fits, falls back to the largest available.

    Args:
        signatures: Iterable/mapping of signature names from the model.
        seq_len: Number of prompt tokens to prefill.

    Returns:
        Tuple of (signature name, signature length).

    Raises:
        ValueError: If the model exposes no prefill signatures.
    """
    candidates = sorted(
        (int(name.split("_")[1]), name)
        for name in signatures
        if name.startswith("prefill_")
    )
    if not candidates:
        raise ValueError("No prefill signatures found in model")
    for length, name in candidates:
        if length >= seq_len:
            return name, length
    # Prompt is longer than every signature: use the largest available.
    length, name = candidates[-1]
    return name, length
def generate(model_path: str, prompt: str, max_new_tokens: int, kv_cache_max_len: int) -> None:
    """Run text generation with the TFLite model.

    Prefills the prompt through the best-fitting ``prefill_*`` signature,
    then greedily decodes one token at a time via the ``decode`` signature,
    streaming decoded text to stdout and timing both phases.

    Args:
        model_path: Path to the exported .tflite model file.
        prompt: Text prompt to condition generation on.
        max_new_tokens: Upper bound on the number of decoded tokens.
        kv_cache_max_len: KV cache capacity; must match the exported model.
    """
    # Load tokenizer
    logger.info("Loading tokenizer from: %s", MODEL_ID)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True
    )
    # Tokenize prompt
    input_ids = tokenizer.encode(prompt)
    logger.info("Prompt: %s", prompt)
    logger.info("Token count: %d", len(input_ids))
    # Load TFLite model
    logger.info("Loading TFLite model from: %s", model_path)
    t0 = time.time()
    interp = tfl_interpreter.Interpreter(model_path=model_path)
    interp.allocate_tensors()
    logger.info("Model loaded in %.1fs", time.time() - t0)
    signatures = interp.get_signature_list()
    logger.info("Available signatures: %s", list(signatures.keys()))
    # Initialize KV cache (zero-filled, shapes depend on layer type).
    kv_cache = create_initial_kv_cache(kv_cache_max_len)
    # --- Prefill phase ---
    sig_name, sig_len = find_prefill_signature(signatures, len(input_ids))
    logger.info("Using prefill signature: %s (padding %d -> %d)", sig_name, len(input_ids), sig_len)
    # Pad input to match signature length.
    # NOTE(review): pads with token id 0 and feeds positions 0..sig_len-1,
    # so pad tokens are written into the cache and decoding resumes at
    # sig_len — assumes the exported model tolerates trailing pads; confirm
    # against the export script.
    padded_ids = input_ids + [0] * (sig_len - len(input_ids))
    tokens = np.array([padded_ids], dtype=np.int32)
    input_pos = np.arange(sig_len, dtype=np.int32)
    prefill_runner = interp.get_signature_runner(sig_name)
    t0 = time.time()
    prefill_out = prefill_runner(tokens=tokens, input_pos=input_pos, **kv_cache)
    prefill_time = time.time() - t0
    logger.info("Prefill done in %.2fs", prefill_time)
    # Update KV cache from prefill output (only keys the signature returned).
    for key in kv_cache:
        if key in prefill_out:
            kv_cache[key] = prefill_out[key]
    # --- Decode phase ---
    # Prefill processed sig_len tokens (including padding). Next decode
    # position is sig_len. We feed the last real token to get the first
    # generated token.
    decode_runner = interp.get_signature_runner("decode")
    generated_ids = list(input_ids)
    current_pos = sig_len  # continue after prefill
    logger.info("Starting decode (max %d tokens)...", max_new_tokens)
    print(f"\n--- Generated text ---\n{prompt}", end="", flush=True)
    t0 = time.time()
    for step in range(max_new_tokens):
        # Feed last token, get next
        tok = np.array([[generated_ids[-1]]], dtype=np.int32)
        pos = np.array([current_pos], dtype=np.int32)
        decode_out = decode_runner(tokens=tok, input_pos=pos, **kv_cache)
        # Update KV cache
        for key in kv_cache:
            if key in decode_out:
                kv_cache[key] = decode_out[key]
        # Greedy sampling: argmax over the final position's logits.
        next_token = int(np.argmax(decode_out["logits"][0, -1]))
        generated_ids.append(next_token)
        current_pos += 1
        # Print token as it is produced (streaming output).
        word = tokenizer.decode([next_token])
        print(word, end="", flush=True)
        # Stop on EOS
        if next_token == tokenizer.eos_token_id:
            break
    decode_time = time.time() - t0
    num_decoded = len(generated_ids) - len(input_ids)
    print(f"\n\n--- Stats ---")
    print(f"Prefill: {prefill_time:.2f}s ({len(input_ids)} tokens)")
    print(f"Decode: {decode_time:.2f}s ({num_decoded} tokens, {num_decoded/decode_time:.1f} tok/s)")
def main():
    """Parse command-line arguments and run TFLite text generation.

    If ``--model_path`` is omitted, the most recently modified ``.tflite``
    file under ``output/`` is used.

    Raises:
        FileNotFoundError: If no model path is given and none can be found.
    """
    parser = argparse.ArgumentParser(description="TFLite inference for Qwen3.5-0.8B")
    parser.add_argument(
        "--model_path",
        default=None,
        help="Path to .tflite model file",
    )
    parser.add_argument(
        "--prompt",
        default="What is the meaning of life?",
        help="Input prompt",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=50,
        help="Maximum tokens to generate",
    )
    parser.add_argument(
        "--kv_cache_max_len",
        type=int,
        default=2048,
        help="KV cache max length (must match exported model)",
    )
    args = parser.parse_args()
    # Auto-find the newest exported model if the user did not name one.
    if args.model_path is None:
        files = glob.glob("output/**/*.tflite", recursive=True)
        if not files:
            raise FileNotFoundError("No .tflite files found in output/")
        # Fix: use the os module directly instead of the __import__("os") hack.
        args.model_path = max(files, key=os.path.getmtime)
        logger.info("Auto-found model: %s", args.model_path)
    generate(args.model_path, args.prompt, args.max_new_tokens, args.kv_cache_max_len)


if __name__ == "__main__":
    main()