File size: 9,268 Bytes
4610862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
Self-contained Gemma 4 E2B INT4 runner for Raspberry Pi 5 (or any ARM64).

Loads ONE .pte (external-cache variant), tokenizes a prompt with the
Gemma chat template, runs token-by-token (prompt feed + decode) threading
KV cache tensors across calls, prints generated text + timing.

Designed to run on the Pi with NO project codebase — just the .pte,
the tokenizer files, and this script. Only Python deps:
    pip install executorch transformers

Files expected next to this script (or pass paths via flags):
    gemma4_e2b_text_int4_extcache.pte
    tokenizer/                          # dir with tokenizer.json + tokenizer_config.json + chat_template.jinja

Usage:
    python pi_runner.py "The capital of France is"
    python pi_runner.py "Why is the sky blue?" --max-new-tokens 50
    python pi_runner.py --verify  # smoke test: forces the built-in reference prompt and compares generated token ids
"""

import argparse
import os
import time

import torch
from transformers import AutoTokenizer
from executorch.runtime import Runtime, Verification

# Default artifact locations: both are resolved relative to the directory
# that contains this script.
HERE = os.path.split(os.path.abspath(__file__))[0]
DEFAULT_TOK = os.path.join(HERE, "tokenizer")
DEFAULT_PTE = os.path.join(HERE, "gemma4_e2b_text_int4_extcache.pte")

# Cache geometry. The attention mask is one slot shorter than the cache
# (511); the .pte is specialized to this upper bound.
MAX_CACHE_LEN = 512
MASK_LEN = MAX_CACHE_LEN - 1
DTYPE = torch.float32  # same dtype the model used at quantize time

# Cache layout for Gemma 4 E2B, hardcoded so the script needs no model
# config. It mirrors what scripts/_external_cache.py:compute_layer_specs
# derives: 35 decoder layers minus num_kv_shared_layers=20 leaves 15 cache
# layers. The layer types repeat in groups of five -- four sliding layers
# (head_dim=256) followed by one full layer (global_head_dim=512).
# Each entry is (head_dim, is_sliding).
GEMMA4_E2B_LAYER_SHAPES = [
    (512, False) if layer_idx % 5 == 4 else (256, True)
    for layer_idx in range(15)
]
BATCH = 1
NUM_KV_HEADS = 1

# Expected result for --verify mode: the FP32/INT4 token-id sequence for the
# prompt "The capital of France is", and its decoded text.
REFERENCE_PROMPT = "The capital of France is"
REFERENCE_TEXT = "The capital of France is **Paris**."
REFERENCE_IDS = [818, 5279, 529, 7001, 563, 5213, 50429, 84750, 106]

# Token ids that end generation: <end_of_turn> (106), <eos> (1), and
# <bos> (2) used as a sentinel.
EOS_TOKEN_IDS = {1, 2, 106}


def allocate_cache_tensors():
    """Create the zero-filled external cache: K, V and cumulative length per layer.

    Returns:
        A ``(k_caches, v_caches, cumlen_caches)`` tuple of lists with one
        entry per row of ``GEMMA4_E2B_LAYER_SHAPES``. K and V tensors have
        shape ``(BATCH, NUM_KV_HEADS, MAX_CACHE_LEN, head_dim)`` and dtype
        ``DTYPE``; each cumulative-length tensor is a single int64 zero.
    """
    def _empty_kv(head_dim):
        # K and V share the same geometry; only head_dim varies by layer.
        return torch.zeros(
            (BATCH, NUM_KV_HEADS, MAX_CACHE_LEN, head_dim), dtype=DTYPE
        )

    # The sliding/full flag does not affect allocation, so it is dropped here.
    head_dims = [dim for dim, _ in GEMMA4_E2B_LAYER_SHAPES]
    k_caches = [_empty_kv(dim) for dim in head_dims]
    v_caches = [_empty_kv(dim) for dim in head_dims]
    cumlen_caches = [torch.zeros(1, dtype=torch.int64) for _ in head_dims]
    return k_caches, v_caches, cumlen_caches


def step(method, token_id, pos, k_caches, v_caches, cumlen_caches):
    """Run one single-token forward call and return logits plus updated caches.

    Args:
        method: Loaded method object; only ``method.execute(args)`` is used.
        token_id: Id of the token to feed.
        pos: Zero-based position of that token. Must satisfy
            ``0 <= pos < MASK_LEN`` so the attention mask can mark it.
        k_caches: Per-layer key cache tensors from the previous call.
        v_caches: Per-layer value cache tensors from the previous call.
        cumlen_caches: Per-layer cumulative-length tensors from the previous
            call.

    Returns:
        ``(logits, k_new, v_new, cumlen_new)`` where the three cache values
        are lists with one entry per cache layer, to be passed to the next
        call.

    Raises:
        ValueError: If ``pos`` is outside ``[0, MASK_LEN)``. Without this
            check the mask slice would silently come out empty or saturated.
        RuntimeError: If the method returns fewer outputs than the layout
            below requires.
    """
    if not 0 <= pos < MASK_LEN:
        raise ValueError(f"pos ({pos}) must be in [0, {MASK_LEN})")

    input_ids = torch.tensor([[token_id]], dtype=torch.long)
    # Mark every position up to and including `pos` as attended.
    attention_mask = torch.zeros(1, MASK_LEN, dtype=torch.long)
    attention_mask[:, :pos + 1] = 1
    position_ids = torch.tensor([[pos]], dtype=torch.long)
    cache_position = torch.tensor([pos], dtype=torch.long)

    # The .pte's execute() takes flat positional inputs.
    # Order matches the wrapper's forward signature:
    #   input_ids, attention_mask, position_ids, cache_position, *k_caches, *v_caches, *cumlen_caches
    args = (input_ids, attention_mask, position_ids, cache_position,
            *k_caches, *v_caches, *cumlen_caches)
    if os.environ.get("DEBUG_SHAPES"):
        for i, a in enumerate(args):
            if hasattr(a, "shape"):
                print(f"    arg[{i:2d}]: shape={tuple(a.shape)} dtype={a.dtype}", flush=True)
    outputs = method.execute(args)

    # Output layout for n cache layers (n = 15 gives 91 outputs):
    #   [0 .. n-1]       K mutations (auto-emitted by torch.export)
    #   [n .. 2n-1]      V mutations
    #   [2n .. 3n-1]     cumlen mutations
    #   [3n]             logits
    #   [3n+1 .. 4n]     K (from wrapper's explicit return -- same tensors)
    #   [4n+1 .. 5n]     V (from wrapper's explicit return)
    #   [5n+1 .. 6n]     cumlen (from wrapper's explicit return)
    # Either copy works; we use the explicit-return half because indices
    # align with the (logits, k, v, cumlen) ordering the wrapper declared.
    # Offsets are derived from n so they cannot drift from the layer table.
    n = len(GEMMA4_E2B_LAYER_SHAPES)
    logits_idx = 3 * n
    base = logits_idx + 1
    expected = base + 3 * n
    if len(outputs) < expected:
        # Slices would otherwise truncate silently and corrupt later steps.
        raise RuntimeError(
            f"expected at least {expected} outputs from the .pte, got {len(outputs)}"
        )

    logits = outputs[logits_idx]
    k_new = list(outputs[base:base + n])
    v_new = list(outputs[base + n:base + 2 * n])
    cumlen_new = list(outputs[base + 2 * n:base + 3 * n])
    return logits, k_new, v_new, cumlen_new


def _tokens_per_second(count, seconds):
    """Return ``count / seconds``, or 0.0 when no measurable time elapsed."""
    return count / seconds if seconds > 0 else 0.0


def main():
    """Command-line entry point: tokenize, feed the prompt, decode greedily, report.

    Steps:
        1. Parse flags; ``--verify`` forces the reference prompt and ensures
           enough new tokens to cover the reference sequence.
        2. Apply the tokenizer's chat template to the prompt.
        3. Load the .pte and allocate the external cache tensors.
        4. Feed the prompt one token at a time, then decode with argmax until
           an EOS id appears or ``--max-new-tokens`` is reached.
        5. Print the generated text and timing; in ``--verify`` mode, compare
           generated ids against ``REFERENCE_IDS``.

    Raises:
        SystemExit: On invalid arguments, when the prompt plus requested new
            tokens do not fit in the attention mask, when the tokenizer
            yields no prompt tokens, and at the end of ``--verify`` mode
            (exit code 0 on match, 1 on mismatch).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", nargs="?", default=REFERENCE_PROMPT,
                        help=f"Prompt text (default: {REFERENCE_PROMPT!r})")
    parser.add_argument("--pte", default=DEFAULT_PTE, help="Path to .pte file")
    parser.add_argument("--tokenizer", default=DEFAULT_TOK, help="Path to tokenizer dir")
    parser.add_argument("--max-new-tokens", type=int, default=20)
    parser.add_argument("--verify", action="store_true",
                        help="Assert prompt+output match the reference (smoke test)")
    args = parser.parse_args()

    if args.verify:
        args.prompt = REFERENCE_PROMPT
        args.max_new_tokens = max(args.max_new_tokens, len(REFERENCE_IDS))

    # Checked after the --verify override, which already guarantees a
    # positive value. One token is always produced from the prompt feed, so
    # anything below 1 cannot be honoured.
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be at least 1")

    print(f"Loading tokenizer from {args.tokenizer}...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

    print(f"Tokenizing prompt: {args.prompt!r}")
    messages = [{"role": "user", "content": [{"type": "text", "text": args.prompt}]}]
    enc = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    )
    prompt_ids = enc["input_ids"][0].tolist()
    n_prompt = len(prompt_ids)
    if not prompt_ids:
        # Without any prompt token there are no logits to sample from.
        raise SystemExit("tokenizer produced no prompt tokens")
    if n_prompt + args.max_new_tokens > MASK_LEN:
        raise SystemExit(
            f"prompt ({n_prompt}) + max_new_tokens ({args.max_new_tokens}) "
            f"exceeds mask_len ({MASK_LEN})"
        )
    print(f"  prompt_len = {n_prompt}")

    print(f"\nLoading {args.pte}...")
    # perf_counter is monotonic, so intervals are immune to wall-clock jumps.
    t0 = time.perf_counter()
    rt = Runtime.get()
    program = rt.load_program(args.pte, verification=Verification.Minimal)
    method = program.load_method("forward")
    print(f"  loaded in {time.perf_counter() - t0:.1f}s")

    print("Allocating cache tensors...")
    k_caches, v_caches, cumlen_caches = allocate_cache_tensors()
    cache_mb = sum(t.numel() * t.element_size() for t in k_caches + v_caches) / 1e6
    print(f"  total cache size: {cache_mb:.1f} MB across {len(k_caches)} layers")

    # --- Prompt token-by-token feed ("slow prefill") ---
    print(f"\nFeeding {n_prompt} prompt tokens (token-by-token; no batched prefill in this design)...")
    t_prefill_start = time.perf_counter()
    last_logits = None
    for i, tok in enumerate(prompt_ids):
        last_logits, k_caches, v_caches, cumlen_caches = step(
            method, tok, i, k_caches, v_caches, cumlen_caches
        )
    t_prefill = time.perf_counter() - t_prefill_start
    prefill_rate = _tokens_per_second(n_prompt, t_prefill)
    print(f"  prompt feed: {t_prefill:.2f}s ({n_prompt} tokens, "
          f"{prefill_rate:.2f} tok/s, ttft equivalent)")

    # Next-token prediction from last prompt position's logits
    next_id = int(last_logits[0, -1].argmax())
    generated = [next_id]
    print(f"  first generated token: id={next_id} text={tokenizer.decode([next_id])!r}")

    # --- Decode loop ---
    print(f"\nDecoding up to {args.max_new_tokens - 1} more tokens...")
    t_decode_start = time.perf_counter()
    # Counts only forward calls made inside the timed window. The first
    # generated token came from the prompt feed and is excluded, so the
    # decode throughput is not inflated by one token.
    n_decode_steps = 0
    for step_idx in range(args.max_new_tokens - 1):
        if next_id in EOS_TOKEN_IDS:
            print(f"  hit EOS (id={next_id}) at decode step {step_idx}")
            break
        pos = n_prompt + step_idx  # position of the token we just produced
        last_logits, k_caches, v_caches, cumlen_caches = step(
            method, next_id, pos, k_caches, v_caches, cumlen_caches
        )
        next_id = int(last_logits[0, -1].argmax())
        generated.append(next_id)
        n_decode_steps += 1
    t_decode = time.perf_counter() - t_decode_start
    decode_rate = _tokens_per_second(n_decode_steps, t_decode)

    text = tokenizer.decode(generated, skip_special_tokens=True)
    print(f"\nGenerated ({len(generated)} tokens): {text!r}")
    print("\n=== Timing ===")
    print(f"  prompt feed:   {t_prefill*1000:7.0f} ms  ({n_prompt} tok @ {prefill_rate:5.2f} tok/s)")
    print(f"  decode:        {t_decode*1000:7.0f} ms  ({n_decode_steps} tok @ {decode_rate:5.2f} tok/s)")
    print(f"  total:         {(t_prefill + t_decode)*1000:7.0f} ms")

    if args.verify:
        # Compare only the overlapping prefix of generated and reference ids.
        compare_len = min(len(generated), len(REFERENCE_IDS))
        match = generated[:compare_len] == REFERENCE_IDS[:compare_len]
        print("\n=== Verify ===")
        print(f"  reference text: {REFERENCE_TEXT!r}")
        print(f"  generated text: {text!r}")
        print(f"  reference ids:  {REFERENCE_IDS[:compare_len]}")
        print(f"  generated ids:  {generated[:compare_len]}")
        print(f"  RESULT: {'PASS' if match else 'FAIL'}")
        raise SystemExit(0 if match else 1)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()