| """ |
| Self-contained Gemma 4 E2B INT4 runner for Raspberry Pi 5 (or any ARM64). |
| |
| Loads ONE .pte (external-cache variant), tokenizes a prompt with the |
| Gemma chat template, runs token-by-token (prompt feed + decode) threading |
| KV cache tensors across calls, prints generated text + timing. |
| |
| Designed to run on the Pi with NO project codebase — just the .pte, |
| the tokenizer files, and this script. Only Python deps: |
| pip install executorch transformers |
| |
| Files expected next to this script (or pass paths via flags): |
| gemma4_e2b_text_int4_extcache.pte |
| tokenizer/ # dir with tokenizer.json + tokenizer_config.json + chat_template.jinja |
| |
| Usage: |
| python pi_runner.py "The capital of France is" |
| python pi_runner.py "Why is the sky blue?" --max-new-tokens 50 |
| python pi_runner.py --verify # ignores any prompt given; runs the built-in reference prompt and asserts the output matches |
| """ |
|
|
| import argparse |
| import os |
| import time |
|
|
| import torch |
| from transformers import AutoTokenizer |
| from executorch.runtime import Runtime, Verification |
|
|
# Default artifact locations, resolved relative to this script's directory.
HERE = os.path.dirname(os.path.abspath(__file__))
DEFAULT_PTE = os.path.join(HERE, "gemma4_e2b_text_int4_extcache.pte")
DEFAULT_TOK = os.path.join(HERE, "tokenizer")

# Sequence capacity of each externally held K/V cache tensor.
MAX_CACHE_LEN = 512
# Width of the attention mask built for every forward call.
MASK_LEN = MAX_CACHE_LEN - 1
# Element type of the K and V cache tensors.
DTYPE = torch.float32

# One (head_dim, is_sliding) pair per cache layer: three repetitions of
# four (256, True) layers followed by one (512, False) layer.
# NOTE(review): is_sliding presumably marks sliding-window attention layers;
# this script only reads head_dim -- confirm against the export.
_LAYER_GROUP = [(256, True)] * 4 + [(512, False)]
GEMMA4_E2B_LAYER_SHAPES = _LAYER_GROUP * 3
NUM_KV_HEADS = 1
BATCH = 1

# Fixed prompt and expected output used by the --verify smoke test.
REFERENCE_PROMPT = "The capital of France is"
REFERENCE_IDS = [818, 5279, 529, 7001, 563, 5213, 50429, 84750, 106]
REFERENCE_TEXT = "The capital of France is **Paris**."

# Generation stops as soon as one of these token ids is produced.
EOS_TOKEN_IDS = {106, 1, 2}
|
|
|
|
def allocate_cache_tensors():
    """Build the initial, zero-filled external cache state.

    For every entry of ``GEMMA4_E2B_LAYER_SHAPES`` this creates a key tensor
    and a value tensor of shape
    ``(BATCH, NUM_KV_HEADS, MAX_CACHE_LEN, head_dim)`` in ``DTYPE``, plus a
    one-element ``int64`` cumulative-length tensor.

    Returns:
        A ``(k_caches, v_caches, cumlen_caches)`` tuple of three lists, each
        holding one tensor per cache layer, in layer order.
    """
    cache_shapes = [
        (BATCH, NUM_KV_HEADS, MAX_CACHE_LEN, head_dim)
        for head_dim, _ in GEMMA4_E2B_LAYER_SHAPES
    ]
    keys = [torch.zeros(shape, dtype=DTYPE) for shape in cache_shapes]
    values = [torch.zeros(shape, dtype=DTYPE) for shape in cache_shapes]
    cumulative_lengths = [
        torch.zeros(1, dtype=torch.int64) for _ in cache_shapes
    ]
    return keys, values, cumulative_lengths
|
|
|
|
def step(method, token_id, pos, k_caches, v_caches, cumlen_caches):
    """Run one single-token forward call and return logits plus new caches.

    Args:
        method: Loaded program method; only ``method.execute(args)`` is used.
        token_id: Id of the token to feed.
        pos: Zero-based sequence position of that token.
        k_caches: Current per-layer key cache tensors, in layer order.
        v_caches: Current per-layer value cache tensors, in layer order.
        cumlen_caches: Current per-layer cumulative-length tensors.

    Returns:
        ``(logits, k_new, v_new, cumlen_new)`` where the last three are lists
        with one entry per cache layer, taken from the call's outputs.
    """
    token = torch.tensor([[token_id]], dtype=torch.long)
    positions = torch.tensor([[pos]], dtype=torch.long)
    cache_pos = torch.tensor([pos], dtype=torch.long)

    # The mask is 1 for positions 0..pos (inclusive) and 0 everywhere else.
    mask = torch.zeros(1, MASK_LEN, dtype=torch.long)
    mask[:, :pos + 1] = 1

    # Flat positional argument list: the four per-step inputs, then all key
    # caches, all value caches, and all cumulative-length tensors.
    model_inputs = (
        token, mask, positions, cache_pos,
        *k_caches, *v_caches, *cumlen_caches,
    )

    # Optional shape dump, enabled by setting the DEBUG_SHAPES env variable.
    if os.environ.get("DEBUG_SHAPES"):
        for idx, value in enumerate(model_inputs):
            if hasattr(value, "shape"):
                print(f" arg[{idx:2d}]: shape={tuple(value.shape)} dtype={value.dtype}", flush=True)

    outputs = method.execute(model_inputs)

    # Output layout as indexed here: logits at index 45, then one run per
    # cache kind (keys, values, cumulative lengths) starting at index 46.
    # NOTE(review): the hard-coded 45/46 offsets presumably mirror how the
    # .pte was exported -- confirm before reusing with a different export.
    layer_count = len(GEMMA4_E2B_LAYER_SHAPES)
    logits = outputs[45]
    k_start = 46
    v_start = k_start + layer_count
    cumlen_start = v_start + layer_count
    cumlen_end = cumlen_start + layer_count

    new_keys = list(outputs[k_start:v_start])
    new_values = list(outputs[v_start:cumlen_start])
    new_cumlens = list(outputs[cumlen_start:cumlen_end])
    return logits, new_keys, new_values, new_cumlens
|
|
|
|
def _tokens_per_second(n_tokens, seconds):
    """Return ``n_tokens / seconds``, or 0.0 when no time has elapsed."""
    return n_tokens / seconds if seconds > 0 else 0.0


def main():
    """Command-line entry point: load the model, generate, report timing.

    Flow:
        1. Parse arguments (``--verify`` forces the reference prompt and at
           least ``len(REFERENCE_IDS)`` new tokens).
        2. Tokenize the prompt through the tokenizer's chat template.
        3. Load the ``forward`` method of the .pte program.
        4. Feed the prompt one token at a time, threading the cache tensors
           through every call; the logits of the last prompt token yield the
           first generated token (greedy argmax).
        5. Decode greedily until an id in ``EOS_TOKEN_IDS`` is produced or
           ``--max-new-tokens`` tokens exist.
        6. Print the text and timing; with ``--verify`` compare against
           ``REFERENCE_IDS``.

    Raises:
        SystemExit: with a message when the prompt is empty or
            prompt + max_new_tokens exceeds ``MASK_LEN``; with status 2 (via
            ``parser.error``) when ``--max-new-tokens`` is below 1; with
            status 0/1 for PASS/FAIL when ``--verify`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", nargs="?", default=REFERENCE_PROMPT,
                        help=f"Prompt text (default: {REFERENCE_PROMPT!r})")
    parser.add_argument("--pte", default=DEFAULT_PTE, help="Path to .pte file")
    parser.add_argument("--tokenizer", default=DEFAULT_TOK, help="Path to tokenizer dir")
    parser.add_argument("--max-new-tokens", type=int, default=20)
    parser.add_argument("--verify", action="store_true",
                        help="Assert prompt+output match the reference (smoke test)")
    args = parser.parse_args()

    if args.verify:
        # The smoke test always runs the reference prompt and needs enough
        # tokens to cover the whole reference output.
        args.prompt = REFERENCE_PROMPT
        args.max_new_tokens = max(args.max_new_tokens, len(REFERENCE_IDS))

    # Checked after the --verify override so "--verify --max-new-tokens 0"
    # keeps working. One token is always produced from the prompt feed, so
    # anything below 1 cannot be honoured.
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be at least 1")

    print(f"Loading tokenizer from {args.tokenizer}...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    print(f"Tokenizing prompt: {args.prompt!r}")
    messages = [{"role": "user", "content": [{"type": "text", "text": args.prompt}]}]
    enc = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    )
    prompt_ids = enc["input_ids"][0].tolist()
    n_prompt = len(prompt_ids)
    if n_prompt == 0:
        # Without at least one fed token there are no logits to sample from.
        raise SystemExit("tokenizer produced an empty prompt; nothing to feed")
    if n_prompt + args.max_new_tokens > MASK_LEN:
        raise SystemExit(
            f"prompt ({n_prompt}) + max_new_tokens ({args.max_new_tokens}) "
            f"exceeds mask_len ({MASK_LEN})"
        )
    print(f" prompt_len = {n_prompt}")

    # perf_counter is monotonic, so a wall-clock adjustment (e.g. an NTP sync)
    # during the run cannot corrupt the measured intervals.
    print(f"\nLoading {args.pte}...")
    t_load_start = time.perf_counter()
    rt = Runtime.get()
    program = rt.load_program(args.pte, verification=Verification.Minimal)
    method = program.load_method("forward")
    print(f" loaded in {time.perf_counter() - t_load_start:.1f}s")

    print("Allocating cache tensors...")
    k_caches, v_caches, cumlen_caches = allocate_cache_tensors()
    cache_mb = sum(t.numel() * t.element_size() for t in k_caches + v_caches) / 1e6
    print(f" total cache size: {cache_mb:.1f} MB across {len(k_caches)} layers")

    # Prompt feed: one forward call per prompt token, in order.
    print(f"\nFeeding {n_prompt} prompt tokens (token-by-token; no batched prefill in this design)...")
    t_prefill_start = time.perf_counter()
    last_logits = None
    for i, tok in enumerate(prompt_ids):
        last_logits, k_caches, v_caches, cumlen_caches = step(
            method, tok, i, k_caches, v_caches, cumlen_caches
        )
    t_prefill = time.perf_counter() - t_prefill_start
    prefill_rate = _tokens_per_second(n_prompt, t_prefill)
    print(f" prompt feed: {t_prefill:.2f}s ({n_prompt} tokens, "
          f"{prefill_rate:.2f} tok/s, ttft equivalent)")

    # The first generated token comes from the last prompt token's logits.
    next_id = int(last_logits[0, -1].argmax())
    generated = [next_id]
    print(f" first generated token: id={next_id} text={tokenizer.decode([next_id])!r}")

    # Greedy decode. Only forward calls made inside this timed window count
    # towards decode throughput; the first token was paid for by the prompt
    # feed above.
    print(f"\nDecoding up to {args.max_new_tokens - 1} more tokens...")
    t_decode_start = time.perf_counter()
    n_decode_steps = 0
    for step_idx in range(args.max_new_tokens - 1):
        if next_id in EOS_TOKEN_IDS:
            print(f" hit EOS (id={next_id}) at decode step {step_idx}")
            break
        pos = n_prompt + step_idx
        last_logits, k_caches, v_caches, cumlen_caches = step(
            method, next_id, pos, k_caches, v_caches, cumlen_caches
        )
        next_id = int(last_logits[0, -1].argmax())
        generated.append(next_id)
        n_decode_steps += 1
    t_decode = time.perf_counter() - t_decode_start
    decode_rate = _tokens_per_second(n_decode_steps, t_decode)

    text = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"\nGenerated ({len(generated)} tokens): {text!r}")
    print("\n=== Timing ===")
    print(f" prompt feed: {t_prefill*1000:7.0f} ms ({n_prompt} tok @ {prefill_rate:5.2f} tok/s)")
    print(f" decode: {t_decode*1000:7.0f} ms ({n_decode_steps} tok @ {decode_rate:5.2f} tok/s)")
    print(f" total: {(t_prefill + t_decode)*1000:7.0f} ms")

    if args.verify:
        # Compare the overlapping prefix of generated and reference ids.
        compare_len = min(len(generated), len(REFERENCE_IDS))
        match = generated[:compare_len] == REFERENCE_IDS[:compare_len]
        print("\n=== Verify ===")
        print(f" reference text: {REFERENCE_TEXT!r}")
        print(f" generated text: {text!r}")
        print(f" reference ids: {REFERENCE_IDS[:compare_len]}")
        print(f" generated ids: {generated[:compare_len]}")
        print(f" RESULT: {'PASS' if match else 'FAIL'}")
        raise SystemExit(0 if match else 1)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|