"""
Self-contained Gemma 4 E2B INT4 runner for Raspberry Pi 5 (or any ARM64).
Loads ONE .pte (external-cache variant), tokenizes a prompt with the
Gemma chat template, runs token-by-token (prompt feed + decode) threading
KV cache tensors across calls, prints generated text + timing.
Designed to run on the Pi with NO project codebase — just the .pte,
the tokenizer files, and this script. Python dependencies:

    pip install executorch transformers

Files expected next to this script (or pass paths via flags):

    gemma4_e2b_text_int4_extcache.pte
    tokenizer/   # dir with tokenizer.json + tokenizer_config.json + chat_template.jinja

Usage:

    python pi_runner.py "The capital of France is"
    python pi_runner.py "Why is the sky blue?" --max-new-tokens 50
    python pi_runner.py --verify   # runs the reference prompt and checks the output against the reference
"""
import argparse
import os
import time
import torch
from transformers import AutoTokenizer
from executorch.runtime import Runtime, Verification
# Default artifact locations: everything lives alongside this script.
HERE = os.path.dirname(os.path.abspath(__file__))
DEFAULT_PTE = os.path.join(HERE, "gemma4_e2b_text_int4_extcache.pte")
DEFAULT_TOK = os.path.join(HERE, "tokenizer")

# Static sequence limits the exported program was built with.
MAX_CACHE_LEN = 512
MASK_LEN = MAX_CACHE_LEN - 1  # 511; .pte specialized to this upper bound
DTYPE = torch.float32  # matches the model's quantize-time dtype

# Hardcoded cache layout for Gemma 4 E2B. Matches what
# scripts/_external_cache.py:compute_layer_specs derives from the model
# config. 35 decoder layers minus num_kv_shared_layers=20 = 15 cache layers.
# Each entry is (head_dim, is_sliding). The layer-type pattern repeats every
# 5 layers: four sliding layers (head_dim=256) followed by one full-attention
# layer (global_head_dim=512), three times over -> 15 entries.
GEMMA4_E2B_LAYER_SHAPES = ([(256, True)] * 4 + [(512, False)]) * 3

NUM_KV_HEADS = 1
BATCH = 1

# Reference for --verify mode (FP32/INT4 token-id sequence for "The capital of France is")
REFERENCE_PROMPT = "The capital of France is"
REFERENCE_IDS = [818, 5279, 529, 7001, 563, 5213, 50429, 84750, 106]
REFERENCE_TEXT = "The capital of France is **Paris**."

# Token ids that stop generation.
EOS_TOKEN_IDS = {106, 1, 2}  # <end_of_turn>, <eos>, <bos>-as-sentinel
def allocate_cache_tensors():
    """Create zero-filled KV-cache state for every cache layer.

    Returns:
        A ``(k_caches, v_caches, cumlen_caches)`` tuple of lists, each holding
        one tensor per entry of ``GEMMA4_E2B_LAYER_SHAPES``. K and V tensors
        have shape ``(BATCH, NUM_KV_HEADS, MAX_CACHE_LEN, head_dim)`` and dtype
        ``DTYPE``; each cumulative-length tensor is a 1-element int64 zero.
    """
    def zeros_for(head_dim):
        # One K or V buffer sized for the full cache length of this layer.
        return torch.zeros((BATCH, NUM_KV_HEADS, MAX_CACHE_LEN, head_dim), dtype=DTYPE)

    head_dims = [head_dim for head_dim, _is_sliding in GEMMA4_E2B_LAYER_SHAPES]
    k_caches = [zeros_for(d) for d in head_dims]
    v_caches = [zeros_for(d) for d in head_dims]
    cumlen_caches = [torch.zeros(1, dtype=torch.int64) for _ in head_dims]
    return k_caches, v_caches, cumlen_caches
def step(method, token_id, pos, k_caches, v_caches, cumlen_caches):
    """Run one single-token forward call; return logits and updated caches.

    Args:
        method: The loaded ``forward`` method of the .pte (from
            ``program.load_method``); invoked as ``method.execute(args)``.
        token_id: Token id to feed at this step.
        pos: Absolute 0-based position of ``token_id`` in the sequence.
        k_caches, v_caches, cumlen_caches: Per-layer cache tensors, one per
            entry of ``GEMMA4_E2B_LAYER_SHAPES`` (e.g. from
            ``allocate_cache_tensors()`` or the previous ``step()`` call).

    Returns:
        ``(logits, k_new, v_new, cumlen_new)``: the logits output plus the
        updated per-layer cache lists to thread into the next call.

    Raises:
        ValueError: If ``pos`` does not fit inside the ``MASK_LEN``-wide mask.
        RuntimeError: If the method returns a different number of outputs
            than the layout documented below.
    """
    if not 0 <= pos < MASK_LEN:
        # Past the mask the token could not attend to itself; fail loudly
        # instead of running with a silently truncated mask.
        raise ValueError(f"pos={pos} is outside the attention mask (0..{MASK_LEN - 1})")

    input_ids = torch.tensor([[token_id]], dtype=torch.long)
    # Fixed-width mask: positions 0..pos are visible, the remainder stays 0.
    attention_mask = torch.zeros(1, MASK_LEN, dtype=torch.long)
    attention_mask[:, :pos + 1] = 1
    position_ids = torch.tensor([[pos]], dtype=torch.long)
    cache_position = torch.tensor([pos], dtype=torch.long)
    # The .pte's execute() takes flat positional inputs.
    # Order matches the wrapper's forward signature:
    # input_ids, attention_mask, position_ids, cache_position, *k_caches, *v_caches, *cumlen_caches
    args = (input_ids, attention_mask, position_ids, cache_position,
            *k_caches, *v_caches, *cumlen_caches)
    if os.environ.get("DEBUG_SHAPES"):
        for i, a in enumerate(args):
            if hasattr(a, "shape"):
                print(f" arg[{i:2d}]: shape={tuple(a.shape)} dtype={a.dtype}", flush=True)
    outputs = method.execute(args)

    # Output layout for n cache layers (6n + 1 outputs; 91 for n = 15):
    #   [0, 3n)          K, V, cumlen mutations (auto-emitted by torch.export)
    #   [3n]             logits
    #   [3n+1, 4n+1)     K      (from wrapper's explicit return — same tensors)
    #   [4n+1, 5n+1)     V      (from wrapper's explicit return)
    #   [5n+1, 6n+1)     cumlen (from wrapper's explicit return)
    # Either copy works; we use the explicit-return half because indices
    # align with the (logits, k, v, cumlen) ordering the wrapper declared.
    n = len(GEMMA4_E2B_LAYER_SHAPES)
    expected = 6 * n + 1
    if len(outputs) != expected:
        raise RuntimeError(
            f"expected {expected} outputs from the .pte for {n} cache layers, "
            f"got {len(outputs)}"
        )
    logits_idx = 3 * n
    logits = outputs[logits_idx]
    base = logits_idx + 1
    k_new = list(outputs[base:base + n])
    v_new = list(outputs[base + n:base + 2 * n])
    cumlen_new = list(outputs[base + 2 * n:base + 3 * n])
    return logits, k_new, v_new, cumlen_new
def _tokens_per_second(n_tokens, seconds):
    """Return ``n_tokens / seconds`` without ever dividing by zero.

    Returns 0.0 when no tokens were processed, and ``inf`` when tokens were
    processed but the measured interval is not positive (clock granularity).
    """
    if n_tokens == 0:
        return 0.0
    if seconds <= 0:
        return float("inf")
    return n_tokens / seconds


def main():
    """Command-line entry point: tokenize, run the .pte token by token, report.

    Parses CLI flags, applies the tokenizer's chat template to the prompt,
    feeds every prompt token through ``step()`` one at a time, then greedily
    (argmax) decodes until an id in ``EOS_TOKEN_IDS`` is produced or
    ``--max-new-tokens`` tokens exist. Prints the generated text and timing.

    With ``--verify`` the prompt is forced to ``REFERENCE_PROMPT`` and the
    process exits with status 0 if the generated ids start with
    ``REFERENCE_IDS``, otherwise 1.

    Raises:
        SystemExit: On invalid flags, an empty or over-long prompt, and
            always at the end of a ``--verify`` run (carrying pass/fail).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", nargs="?", default=REFERENCE_PROMPT,
                        help=f"Prompt text (default: {REFERENCE_PROMPT!r})")
    parser.add_argument("--pte", default=DEFAULT_PTE, help="Path to .pte file")
    parser.add_argument("--tokenizer", default=DEFAULT_TOK, help="Path to tokenizer dir")
    parser.add_argument("--max-new-tokens", type=int, default=20)
    parser.add_argument("--verify", action="store_true",
                        help="Assert prompt+output match the reference (smoke test)")
    args = parser.parse_args()
    if args.verify:
        # Reference ids are only meaningful for the reference prompt, and we
        # need room to generate all of them.
        args.prompt = REFERENCE_PROMPT
        args.max_new_tokens = max(args.max_new_tokens, len(REFERENCE_IDS))
    if args.max_new_tokens < 1:
        # The first token always comes out of the prompt feed, so 0/negative
        # values would still generate one token; reject them explicitly.
        parser.error("--max-new-tokens must be >= 1")

    print(f"Loading tokenizer from {args.tokenizer}...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    print(f"Tokenizing prompt: {args.prompt!r}")
    messages = [{"role": "user", "content": [{"type": "text", "text": args.prompt}]}]
    enc = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    )
    prompt_ids = enc["input_ids"][0].tolist()
    n_prompt = len(prompt_ids)
    if n_prompt == 0:
        # Without at least one prompt token there are no logits to decode from.
        raise SystemExit("chat template produced an empty prompt")
    if n_prompt + args.max_new_tokens > MASK_LEN:
        raise SystemExit(
            f"prompt ({n_prompt}) + max_new_tokens ({args.max_new_tokens}) "
            f"exceeds mask_len ({MASK_LEN})"
        )
    print(f" prompt_len = {n_prompt}")

    print(f"\nLoading {args.pte}...")
    # perf_counter is monotonic and high-resolution, unlike wall-clock time().
    t0 = time.perf_counter()
    rt = Runtime.get()
    program = rt.load_program(args.pte, verification=Verification.Minimal)
    method = program.load_method("forward")
    print(f" loaded in {time.perf_counter() - t0:.1f}s")

    print("Allocating cache tensors...")
    k_caches, v_caches, cumlen_caches = allocate_cache_tensors()
    cache_mb = sum(t.numel() * t.element_size() for t in k_caches + v_caches) / 1e6
    print(f" total cache size: {cache_mb:.1f} MB across {len(k_caches)} layers")

    # --- Prompt token-by-token feed ("slow prefill") ---
    print(f"\nFeeding {n_prompt} prompt tokens (token-by-token; no batched prefill in this design)...")
    t_prefill_start = time.perf_counter()
    last_logits = None
    for i, tok in enumerate(prompt_ids):
        last_logits, k_caches, v_caches, cumlen_caches = step(
            method, tok, i, k_caches, v_caches, cumlen_caches
        )
    t_prefill = time.perf_counter() - t_prefill_start
    prefill_rate = _tokens_per_second(n_prompt, t_prefill)
    print(f" prompt feed: {t_prefill:.2f}s ({n_prompt} tokens, "
          f"{prefill_rate:.2f} tok/s, ttft equivalent)")

    # Greedy next-token prediction from the last prompt position's logits.
    next_id = int(last_logits[0, -1].argmax())
    generated = [next_id]
    print(f" first generated token: id={next_id} text={tokenizer.decode([next_id])!r}")

    # --- Decode loop ---
    print(f"\nDecoding up to {args.max_new_tokens - 1} more tokens...")
    t_decode_start = time.perf_counter()
    # Count only forward calls made in this loop: the first generated token
    # was produced (and timed) by the prompt feed above.
    n_decode_steps = 0
    for step_idx in range(args.max_new_tokens - 1):
        if next_id in EOS_TOKEN_IDS:
            print(f" hit EOS (id={next_id}) at decode step {step_idx}")
            break
        pos = n_prompt + step_idx  # position of the token we just produced
        last_logits, k_caches, v_caches, cumlen_caches = step(
            method, next_id, pos, k_caches, v_caches, cumlen_caches
        )
        next_id = int(last_logits[0, -1].argmax())
        generated.append(next_id)
        n_decode_steps += 1
    t_decode = time.perf_counter() - t_decode_start
    decode_rate = _tokens_per_second(n_decode_steps, t_decode)

    text = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"\nGenerated ({len(generated)} tokens): {text!r}")
    print("\n=== Timing ===")
    print(f" prompt feed: {t_prefill*1000:7.0f} ms ({n_prompt} tok @ {prefill_rate:5.2f} tok/s)")
    print(f" decode: {t_decode*1000:7.0f} ms ({n_decode_steps} tok @ {decode_rate:5.2f} tok/s)")
    print(f" total: {(t_prefill + t_decode)*1000:7.0f} ms")

    if args.verify:
        n_ref = len(REFERENCE_IDS)
        generated_head = generated[:n_ref]
        # Strict comparison: an output shorter than the reference fails.
        match = generated_head == REFERENCE_IDS
        print("\n=== Verify ===")
        print(f" reference text: {REFERENCE_TEXT!r}")
        print(f" generated text: {text!r}")
        print(f" reference ids: {REFERENCE_IDS}")
        print(f" generated ids: {generated_head}")
        print(f" RESULT: {'PASS' if match else 'FAIL'}")
        raise SystemExit(0 if match else 1)
if __name__ == "__main__":
    # Run only when executed as a script; importing this module stays side-effect free.
    main()