# Source: ARBS/testing/attention/test_ring_buffer.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified)
"""Unit tests for GPURingBuffer and KVLedger."""
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.attention.ring_buffer import GPURingBuffer
from arbitor.attention.kv_ledger import KVLedger
def test_rb_append_wrap():
    """Appending past capacity overwrites the oldest slots; the newest three survive."""
    ring = GPURingBuffer(4)
    for item in range(6):
        ring.append(item)
    newest = ring.get_last_n(3).tolist()
    assert newest == [3, 4, 5], f"got {newest}"
    print(" PASS test_rb_append_wrap")
def test_rb_contiguous_no_wrap():
    """Below capacity, the stored values come back exactly as appended."""
    ring = GPURingBuffer(4)
    pushed = list(range(3))
    for item in pushed:
        ring.append(item)
    assert ring.get_last_n(3).tolist() == pushed
    print(" PASS test_rb_contiguous_no_wrap")
def test_rb_empty():
    """A freshly constructed buffer yields an empty tensor when asked for its tail."""
    fresh = GPURingBuffer(4)
    requested = fresh.get_last_n(3)
    assert requested.numel() == 0
    print(" PASS test_rb_empty")
def test_rb_reset():
    """reset() zeroes the write pointer and size, leaving nothing to read back."""
    ring = GPURingBuffer(4)
    for item in range(3):
        ring.append(item)
    ring.reset()
    # Checked separately; fails at the same point as the original combined `and`.
    assert ring.ptr == 0
    assert ring.size == 0
    assert ring.get_last_n(3).numel() == 0
    print(" PASS test_rb_reset")
def test_rb_multi_dim():
    """Rows of width ``dim`` are stored and the last n are returned oldest-first."""
    rb = GPURingBuffer(3, dtype=torch.float32, dim=4)
    assert rb.buffer.shape == (3, 4)
    for i in range(3):
        rb.append(torch.ones(4) * i)
    last = rb.get_last_n(2)
    assert last.shape == (2, 4), f"shape {last.shape}"
    # A shape check alone would not catch swapped or stale rows: the two newest
    # rows appended above are the ones filled with 1.0 and 2.0, in that order.
    expected = [[1.0] * 4, [2.0] * 4]
    assert last.tolist() == expected, f"got {last.tolist()}"
    print(" PASS test_rb_multi_dim")
def test_rb_get_all():
    """After wrapping, get_all() returns the retained values oldest-first."""
    ring = GPURingBuffer(4)
    for item in range(6):
        ring.append(item)
    contents = ring.get_all().tolist()
    assert contents == [2, 3, 4, 5], f"got {contents}"
    print(" PASS test_rb_get_all")
def test_rb_partial():
    """A partially filled buffer exposes only the slots written so far."""
    ring = GPURingBuffer(8)
    for item in range(5):
        ring.append(item)
    everything = ring.get_all().tolist()
    assert everything == [0, 1, 2, 3, 4]
    tail = ring.get_last_n(3).tolist()
    assert tail == [2, 3, 4]
    print(" PASS test_rb_partial")
def test_kv_ledger_basic():
    """len() counts appended entries and the sliding window holds the newest five."""
    ledger = KVLedger(256)
    count = 100
    for token in range(count):
        ledger.append(token)
    assert len(ledger) == count
    window = ledger.get_sliding_window(5).tolist()
    assert window == list(range(95, 100))
    print(" PASS test_kv_ledger_basic")
def test_kv_ledger_sliding_window():
    """With the ledger filled exactly to capacity, the window holds the newest 5 entries."""
    ledger = KVLedger(32)
    for token in range(32):
        ledger.append(token)
    got = ledger.get_sliding_window(5).tolist()
    assert got == [27, 28, 29, 30, 31], f"got {got}"
    print(" PASS test_kv_ledger_sliding_window")
def test_kv_ledger_sparse():
    """A stride-3 sample over 10 appended entries yields 4 items."""
    ledger = KVLedger(16)
    for token in range(10):
        ledger.append(token)
    sampled = ledger.get_sparse(stride=3)
    n_sampled = len(sampled)
    assert n_sampled == 4, f"len={n_sampled}"
    print(" PASS test_kv_ledger_sparse")
def test_kv_ledger_reset():
    """reset() leaves the ledger reporting zero entries."""
    ledger = KVLedger(8)
    for token in range(5):
        ledger.append(token)
    ledger.reset()
    assert len(ledger) == 0
    print(" PASS test_kv_ledger_reset")
def test_cuda_device_move():
    """Moving the buffer to CUDA relocates its storage and keeps append/read working.

    Prints a SKIP line and returns without asserting when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        ring = GPURingBuffer(4).to("cuda")
        assert ring.buffer.device.type == "cuda"
        ring.append(42)
        assert ring.get_last_n(1).tolist() == [42]
        print(" PASS test_cuda_device_move")
    else:
        print(" SKIP test_cuda_device_move (no cuda)")
if __name__ == "__main__":
    # Run every test in declaration order when executed as a script; the first
    # failing assert raises and stops the run before the final summary line.
    for run_test in (
        test_rb_append_wrap,
        test_rb_contiguous_no_wrap,
        test_rb_empty,
        test_rb_reset,
        test_rb_multi_dim,
        test_rb_get_all,
        test_rb_partial,
        test_kv_ledger_basic,
        test_kv_ledger_sliding_window,
        test_kv_ledger_sparse,
        test_kv_ledger_reset,
        test_cuda_device_move,
    ):
        run_test()
    print("\nAll ring buffer tests PASS")