# turboquant/scripts/test_cache.py
# Uploaded by vivekvar via huggingface_hub (commit d4ec3e8, verified)
"""Test TurboQuantCache integration with the HF Transformers cache API."""
import sys
sys.path.insert(0, "/home/azureuser/turboquant")
import torch
from types import SimpleNamespace
from turboquant.cache import TurboQuantCache, TurboQuantLayer
def test_cache_basic():
    """Test TurboQuantCache with mock model config, simulating Qwen2.5-32B."""
    print("=" * 60)
    print("TEST: TurboQuantCache basic operations")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Minimal stand-in for a Qwen2.5-32B config — only the fields the cache reads.
    config = SimpleNamespace(
        num_hidden_layers=4,  # 4 layers for testing (the real model has 64)
        hidden_size=5120,
        num_attention_heads=40,
    )
    # HF cache classes call get_text_config(); point it back at the config itself.
    config.get_text_config = lambda decoder=True: config

    cache = TurboQuantCache(config, nbits=4, residual_length=4, device=device)
    print(f" Created cache with {len(cache.layers)} layers")

    batch, heads, head_dim = 1, 8, 128

    def _rand_kv(n_tokens):
        # Fresh random key/value tensors of the shape the cache expects.
        shape = (batch, heads, n_tokens, head_dim)
        return (
            torch.randn(*shape, device=device, dtype=torch.bfloat16),
            torch.randn(*shape, device=device, dtype=torch.bfloat16),
        )

    # Prefill phase: push 16 tokens at once through every layer.
    for layer_idx in range(4):
        k, v = _rand_kv(16)
        k_out, v_out = cache.update(k, v, layer_idx)
        print(f" Layer {layer_idx} prefill: input ({k.shape}) → output ({k_out.shape})")
        assert k_out.shape == (batch, heads, 16, head_dim)
        assert k_out.dtype == torch.bfloat16

    # Decode phase: one token per step for 8 steps, across all layers.
    for step in range(8):
        expected_len = 16 + step + 1
        for layer_idx in range(4):
            k, v = _rand_kv(1)
            k_out, v_out = cache.update(k, v, layer_idx)
            assert k_out.shape == (batch, heads, expected_len, head_dim), \
                f"Expected seq_len={expected_len}, got {k_out.shape[-2]}"
            assert k_out.dtype == torch.bfloat16
        if step in (0, 7):
            print(f" Decode step {step}: seq_len={k_out.shape[-2]}")

    # Final length should reflect prefill + all decode steps.
    seq_len = cache.get_seq_length(0)
    print(f" Final seq_length: {seq_len}")
    print("\n PASS: Cache operations correct\n")
def test_cache_memory():
    """Compare KV-cache GPU memory: BF16 DynamicCache vs 4-bit TurboQuantCache.

    Fills both caches with identical-shaped random KV tensors for every layer
    and reports allocated bytes and the compression ratio. Skips (with a
    message) when CUDA is unavailable, since the measurement relies on
    torch.cuda.memory_allocated().
    """
    from transformers.cache_utils import DynamicCache

    print("=" * 60)
    print("TEST: Memory comparison vs DynamicCache")
    print("=" * 60)

    # Check availability before committing to a CUDA device.
    if not torch.cuda.is_available():
        print(" SKIP: No CUDA available")
        return
    device = "cuda"

    config = SimpleNamespace(
        num_hidden_layers=64,
        hidden_size=5120,
        num_attention_heads=40,
    )
    config.get_text_config = lambda decoder=True: config
    # Single source of truth for the layer count (was hard-coded 64 in three places).
    num_layers = config.num_hidden_layers

    batch, heads, head_dim = 1, 8, 128
    seq_len = 4096

    def _fill(cache):
        # Push one full-length KV pair per layer into the given cache.
        for layer_idx in range(num_layers):
            k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
            v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
            cache.update(k, v, layer_idx)

    # --- DynamicCache (BF16 baseline) ---
    # NOTE: measure allocated (not peak) memory, so no peak-stat reset is needed.
    torch.cuda.empty_cache()
    mem_before = torch.cuda.memory_allocated()
    dyn_cache = DynamicCache()
    _fill(dyn_cache)
    mem_dynamic = torch.cuda.memory_allocated() - mem_before
    del dyn_cache
    torch.cuda.empty_cache()

    # --- TurboQuantCache (4-bit) ---
    mem_before = torch.cuda.memory_allocated()
    tq_cache = TurboQuantCache(config, nbits=4, residual_length=1, device=device)
    _fill(tq_cache)
    mem_turboquant = torch.cuda.memory_allocated() - mem_before
    del tq_cache
    torch.cuda.empty_cache()

    # max(..., 1) guards against division by zero if the quantized cache
    # reports ~0 additional allocated bytes.
    ratio = mem_dynamic / max(mem_turboquant, 1)
    print(f" Seq length: {seq_len}")
    print(f" Layers: {num_layers}")
    print(f" DynamicCache: {mem_dynamic / 1024**2:.1f} MB")
    print(f" TurboQuantCache: {mem_turboquant / 1024**2:.1f} MB")
    print(f" Compression: {ratio:.2f}x")
    print("\n PASS: Memory comparison done\n")
if __name__ == "__main__":
test_cache_basic()
test_cache_memory()
print("=" * 60)
print("ALL CACHE TESTS PASSED")
print("=" * 60)