|
|
""" |
|
|
Test Engine class. Example run: |
|
|
|
|
|
python -m pytest tests/test_engine.py -v |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from nanochat.engine import KVCache |
|
|
|
|
|
def test_kv_cache_resize():
    """
    The KV cache was not resized correctly, more information here:

    https://github.com/karpathy/nanochat/pull/186

    This test reproduces the issue and will be merged alongside the fix.
    """
    # Small, mutually-distinct dimensions so any axis mix-up shows up quickly.
    batch_size, num_heads, seq_len, head_dim, num_layers = 2, 3, 4, 5, 6

    kv_cache = KVCache(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
        num_layers=num_layers,
    )

    # Shape of a single-token K (or V) entry for one layer.
    token_shape = (batch_size, num_heads, 1, head_dim)

    def insert_token(token_idx):
        # Keys carry token_idx and values carry token_idx * 100, so any
        # corruption introduced by a resize is immediately recognizable.
        for layer_idx in range(num_layers):
            kv_cache.insert_kv(
                layer_idx,
                torch.full(token_shape, fill_value=float(token_idx), dtype=torch.float32),
                torch.full(token_shape, fill_value=float(token_idx * 100), dtype=torch.float32),
            )

    # Fill the cache exactly to its initial capacity (seq_len == 4).
    for token_idx in range(4):
        insert_token(token_idx)

    # Snapshot the cache contents before triggering the resize.
    original_cache = kv_cache.kv_cache.clone()
    original_seq_len = original_cache.shape[4]

    # A fifth token exceeds the capacity and must force the cache to grow.
    insert_token(4)

    new_seq_len = kv_cache.kv_cache.shape[4]
    assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"

    # Every previously inserted token must survive the resize in every layer,
    # both against the known fill values and against the pre-resize snapshot.
    for layer_idx in range(num_layers):
        for token_idx in range(4):
            expected_k = float(token_idx)
            expected_v = float(token_idx * 100)
            actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
            actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
            assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
            assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
            assert (actual_k == original_cache[layer_idx, 0, :, :, token_idx, :]).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
            assert (actual_v == original_cache[layer_idx, 1, :, :, token_idx, :]).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
|
|
|