| """ |
| Test Engine class. Example run: |
| |
| python -m pytest tests/test_engine.py -v |
| """ |
|
|
| import torch |
| from nanochat.engine import KVCache, Engine |
| from dataclasses import dataclass |
|
|
|
|
| |
| |
|
|
@dataclass
class MockConfig:
    """Smallest model configuration used by the Engine tests.

    Every field has a default, so ``MockConfig()`` can be built with no
    arguments. The field names presumably mirror the attributes that
    ``Engine`` reads from a real model config when sizing its KV cache --
    confirm against ``nanochat.engine``.
    """

    n_kv_head: int = 4
    n_head: int = 4
    n_embd: int = 64
    n_layer: int = 2
    sequence_len: int = 128
|
|
|
|
class MockModel:
    """
    Stand-in model whose logits are all zero.

    Identical logits for every vocabulary entry mean no token is preferred,
    so (assuming the sampler turns logits into probabilities in the usual
    way) sampling at temperature > 0 is spread evenly over the vocab and
    independent samples should almost never coincide.
    """

    def __init__(self, vocab_size=262):
        self._device = "cpu"
        self.config = MockConfig()
        self.vocab_size = vocab_size

    def get_device(self):
        """Return the device string this mock reports (always "cpu")."""
        return self._device

    def forward(self, ids, kv_cache=None):
        """
        Produce all-zero logits of shape (batch, time, vocab_size).

        ``ids`` must be 2-D; only its shape is used. When a ``kv_cache`` is
        supplied, it is advanced by the number of time steps in ``ids``.
        """
        batch, steps = ids.shape

        # Move the cache position forward by the tokens just "processed".
        if kv_cache is not None:
            kv_cache.advance(steps)

        # All zeros: every token is equally likely.
        return torch.zeros(batch, steps, self.vocab_size)
|
|
|
|
class ByteTokenizer:
    """
    Minimal byte-level tokenizer for the tests.

    Ids 0-255 map one-to-one onto raw UTF-8 bytes; ids 256 and above are
    reserved for the special tokens listed in ``_special_tokens``.
    """

    def __init__(self):
        # Special tokens sit directly above the byte range.
        self._special_tokens = {
            "<|python_start|>": 256,
            "<|python_end|>": 257,
            "<|output_start|>": 258,
            "<|output_end|>": 259,
            "<|assistant_end|>": 260,
            "<|bos|>": 261,
        }
        self._bos = self._special_tokens["<|bos|>"]

    def encode_special(self, s):
        """Look up the id of special token ``s``; unknown names raise KeyError."""
        return self._special_tokens[s]

    def get_bos_token_id(self):
        """Return the id of the beginning-of-sequence token."""
        return self._bos

    def encode(self, s, prepend=None):
        """
        Encode ``s`` as a list of its UTF-8 byte values.

        If ``prepend`` is not None (0 is a valid value), it is placed in
        front of the byte ids.
        """
        byte_ids = list(s.encode("utf-8"))
        if prepend is None:
            return byte_ids
        return [prepend, *byte_ids]

    def decode(self, tokens):
        """
        Decode token ids back to text.

        Ids of 256 or more (the special tokens) are dropped; invalid UTF-8
        sequences decode to the replacement character instead of raising.
        """
        raw = bytes(tok for tok in tokens if tok < 256)
        return raw.decode("utf-8", errors="replace")
|
|
def test_kv_cache_basic():
    """Check KVCache construction, position tracking, reset and layer access (FA3)."""
    batch_size, num_heads, seq_len, head_dim, num_layers = 2, 3, 64, 5, 6

    # Layout asserted below: (layers, batch, time, heads, head_dim).
    per_layer_shape = (batch_size, seq_len, num_heads, head_dim)
    full_shape = (num_layers, *per_layer_shape)

    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
        device="cpu",
        dtype=torch.float32,
    )

    # A fresh cache starts at position 0 with fully allocated buffers.
    assert kv_cache.get_pos() == 0
    assert kv_cache.k_cache.shape == full_shape
    assert kv_cache.v_cache.shape == full_shape

    # advance() accumulates: 10, then 10 + 5.
    for step, expected_pos in ((10, 10), (5, 15)):
        kv_cache.advance(step)
        assert kv_cache.get_pos() == expected_pos

    # reset() rewinds the position to the start.
    kv_cache.reset()
    assert kv_cache.get_pos() == 0

    # Per-layer access drops the leading layer dimension.
    k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
    assert k_layer0.shape == per_layer_shape
    assert v_layer0.shape == per_layer_shape
|
|
|
|
def test_kv_cache_prefill():
    """Check that KVCache.prefill() carries over both the position and the cached data."""
    filled = 16  # number of positions written into the source cache
    shared_kwargs = dict(
        batch_size=1,
        num_heads=4,
        head_dim=8,
        num_layers=2,
        device="cpu",
        dtype=torch.float32,
    )

    # Source cache: mark the first `filled` positions of layer 0, batch row 0.
    src_cache = KVCache(seq_len=32, **shared_kwargs)
    src_cache.k_cache[0, 0, :filled, :, :] = 1.0
    src_cache.v_cache[0, 0, :filled, :, :] = 2.0
    src_cache.advance(filled)

    # Destination cache has a longer sequence capacity than the source.
    dst_cache = KVCache(seq_len=64, **shared_kwargs)

    dst_cache.prefill(src_cache)

    # The position is taken over from the source...
    assert dst_cache.get_pos() == filled

    # ...and so are the marked key/value entries.
    assert (dst_cache.k_cache[0, 0, :filled, :, :] == 1.0).all()
    assert (dst_cache.v_cache[0, 0, :filled, :, :] == 2.0).all()
|
|
|
|
def test_multi_sample_first_token_diversity():
    """
    Each sample must draw its own first token rather than share one.

    Regression test: the first token after prefill used to be sampled once
    and broadcast to every row, so all samples began identically. The fix
    expands the prefill logits to num_samples rows and samples each row
    independently.

    With uniform logits over 262 tokens and 16 samples, the chance that all
    rows independently land on the same token is (1/262)^15, roughly 10^-36.
    Identical first tokens therefore indicate broadcasting, not bad luck.
    """
    num_samples = 16
    # BOS followed by the bytes of "Hello".
    prompt_tokens = [261, 72, 101, 108, 108, 111]

    engine = Engine(MockModel(vocab_size=262), ByteTokenizer())

    # max_tokens=1, so the column kept after the loop is the first sampled token per row.
    first_tokens = []
    for token_column, _masks in engine.generate(
        prompt_tokens,
        num_samples=num_samples,
        max_tokens=1,
        temperature=1.0,
        seed=42,
    ):
        first_tokens = token_column

    # Independent sampling must yield at least two distinct tokens.
    unique_tokens = set(first_tokens)
    assert len(unique_tokens) > 1, (
        f"All {num_samples} samples got the same first token ({first_tokens[0]}). "
        f"With uniform logits, this is statistically impossible (~10^-36 probability) "
        f"unless tokens are being broadcast instead of independently sampled."
    )
|
|
|
|
def test_seed_reproducibility():
    """Repeated generation with one seed must give the same tokens every time."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for seed in (1, 42, 123, 999):
        # Three runs with identical arguments.
        runs = []
        for _ in range(3):
            tokens, _masks = engine.generate_batch(prompt, max_tokens=5, seed=seed)
            runs.append(tokens)
        first, second, third = runs
        assert first == second == third, "Same seed must produce identical output for the same prompt."
|
|
|
|
def test_temperature_zero_determinism():
    """At temperature 0 the output must not depend on the seed."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    # Same prompt, three different seeds.
    runs = []
    for seed in (1, 42, 123):
        tokens, _masks = engine.generate_batch(prompt, temperature=0.0, max_tokens=5, seed=seed)
        runs.append(tokens)

    first, second, third = runs
    assert first == second == third, "Temperature=0 must result in the same output for the same prompt regardless of seed."
|
|
|
|
def test_max_tokens_respected():
    """Generation never produces more than max_tokens new tokens."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]
    prompt_len = len(prompt)

    for max_tokens in (1, 4, 16, 64):
        sequences, _masks = engine.generate_batch(prompt, max_tokens=max_tokens)
        # The prompt length is subtracted, so returned sequences are expected to include the prompt.
        num_generated_tokens = len(sequences[0]) - prompt_len
        assert num_generated_tokens <= max_tokens, f"Generated {num_generated_tokens} tokens, expected max_tokens={max_tokens} or less."
|
|
|
|
def test_num_samples_count():
    """Requesting N samples must return exactly N sequences."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]

    for num_samples in (1, 4, 16, 64):
        generated = engine.generate_batch(prompt, num_samples=num_samples, max_tokens=3)
        results, _masks = generated
        assert len(results) == num_samples, f"Expected {num_samples} sequences from {num_samples} samples, got {len(results)}"
|
|
|
|
def test_different_seeds_introduce_variation_when_temperature_nonzero():
    """At temperature > 0, changing the seed should change what gets sampled."""
    engine = Engine(MockModel(), ByteTokenizer())
    prompt = [261, 72, 101, 108, 108, 111]
    seeds = (1, 42, 123, 999, 1000, 1001, 1002, 1003, 1004, 1005)

    # Collect the distinct first sequences; tuples because lists are unhashable.
    outputs = set()
    for seed in seeds:
        sequences, _masks = engine.generate_batch(
            prompt,
            temperature=1.0,
            max_tokens=5,
            seed=seed,
        )
        outputs.add(tuple(sequences[0]))

    # Ten seeds over uniform logits should not all produce the same sequence.
    assert len(outputs) > 1, "All seeds produced the same output which is statistically highly improbable."
|
|