File size: 4,617 Bytes
d4ec3e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Test TurboQuantCache integration with the HF Transformers cache API."""

import sys
sys.path.insert(0, "/home/azureuser/turboquant")

import torch
from types import SimpleNamespace
from turboquant.cache import TurboQuantCache, TurboQuantLayer


def test_cache_basic():
    """Test TurboQuantCache with mock model config, simulating Qwen2.5-32B.

    Runs a 16-token prefill per layer, then 8 single-token decode steps,
    checking output shape/dtype after every ``cache.update`` and asserting
    the final cached sequence length (prefill + decode tokens).
    """
    print("=" * 60)
    print("TEST: TurboQuantCache basic operations")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Mock Qwen2.5-32B config (just the fields we need).
    # num_layers is named once so the config and the loops below can't drift.
    num_layers = 4  # Use 4 layers for testing (not 64)
    config = SimpleNamespace(
        num_hidden_layers=num_layers,
        hidden_size=5120,
        num_attention_heads=40,
    )
    # Mock get_text_config for compatibility
    config.get_text_config = lambda decoder=True: config

    cache = TurboQuantCache(config, nbits=4, residual_length=4, device=device)
    print(f"  Created cache with {len(cache.layers)} layers")

    batch, heads, head_dim = 1, 8, 128
    prefill_len, decode_steps = 16, 8

    # Simulate prefill: all prefill tokens at once
    for layer_idx in range(num_layers):
        k = torch.randn(batch, heads, prefill_len, head_dim, device=device, dtype=torch.bfloat16)
        v = torch.randn(batch, heads, prefill_len, head_dim, device=device, dtype=torch.bfloat16)

        k_out, v_out = cache.update(k, v, layer_idx)
        print(f"  Layer {layer_idx} prefill: input ({k.shape}) → output ({k_out.shape})")
        assert k_out.shape == (batch, heads, prefill_len, head_dim)
        assert k_out.dtype == torch.bfloat16

    # Simulate decode: 1 token at a time
    for step in range(decode_steps):
        for layer_idx in range(num_layers):
            k = torch.randn(batch, heads, 1, head_dim, device=device, dtype=torch.bfloat16)
            v = torch.randn(batch, heads, 1, head_dim, device=device, dtype=torch.bfloat16)

            k_out, v_out = cache.update(k, v, layer_idx)

            expected_len = prefill_len + step + 1
            assert k_out.shape == (batch, heads, expected_len, head_dim), \
                f"Expected seq_len={expected_len}, got {k_out.shape[-2]}"
            assert k_out.dtype == torch.bfloat16

        if step == 0 or step == decode_steps - 1:
            print(f"  Decode step {step}: seq_len={k_out.shape[-2]}")

    # Check sequence length: the cache must have tracked every token it saw.
    # (Previously only printed — a length-tracking bug would pass silently.)
    seq_len = cache.get_seq_length(0)
    print(f"  Final seq_length: {seq_len}")
    assert seq_len == prefill_len + decode_steps, \
        f"Expected seq_length={prefill_len + decode_steps}, got {seq_len}"

    print("\n  PASS: Cache operations correct\n")


def test_cache_memory():
    """Compare memory usage: DynamicCache vs TurboQuantCache.

    Fills both caches with identically-shaped BF16 key/value tensors across
    64 layers and reports the allocated-memory delta of each, plus the
    resulting compression ratio. Skips when CUDA is unavailable, since
    torch.cuda.memory_allocated only tracks device allocations.
    """
    from transformers.cache_utils import DynamicCache

    print("=" * 60)
    print("TEST: Memory comparison vs DynamicCache")
    print("=" * 60)

    device = "cuda"
    if not torch.cuda.is_available():
        print("  SKIP: No CUDA available")
        return

    config = SimpleNamespace(
        num_hidden_layers=64,
        hidden_size=5120,
        num_attention_heads=40,
    )
    config.get_text_config = lambda decoder=True: config

    batch, heads, head_dim = 1, 8, 128
    seq_len = 4096
    num_layers = config.num_hidden_layers

    # --- DynamicCache (BF16 baseline) ---
    torch.cuda.empty_cache()
    mem_before = torch.cuda.memory_allocated()

    dyn_cache = DynamicCache()
    for layer_idx in range(num_layers):
        k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        dyn_cache.update(k, v, layer_idx)

    # Drop the loop's last k/v references before measuring, so only
    # cache-held memory is counted (matters for the quantized cache below,
    # where update() does not retain the raw inputs).
    del k, v
    mem_dynamic = torch.cuda.memory_allocated() - mem_before
    del dyn_cache
    torch.cuda.empty_cache()

    # --- TurboQuantCache (4-bit) ---
    mem_before = torch.cuda.memory_allocated()

    tq_cache = TurboQuantCache(config, nbits=4, residual_length=1, device=device)
    for layer_idx in range(num_layers):
        k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
        tq_cache.update(k, v, layer_idx)

    del k, v  # as above: exclude the transient BF16 inputs from the delta
    mem_turboquant = torch.cuda.memory_allocated() - mem_before
    del tq_cache
    torch.cuda.empty_cache()

    # Guard against division by zero if the quantized cache rounds to 0 bytes.
    ratio = mem_dynamic / max(mem_turboquant, 1)
    print(f"  Seq length:     {seq_len}")
    print(f"  Layers:         64")
    print(f"  DynamicCache:   {mem_dynamic / 1024**2:.1f} MB")
    print(f"  TurboQuantCache: {mem_turboquant / 1024**2:.1f} MB")
    print(f"  Compression:    {ratio:.2f}x")
    print("\n  PASS: Memory comparison done\n")


if __name__ == "__main__":
    # Script entry point: run every cache test, then print the summary banner.
    banner = "=" * 60
    for test in (test_cache_basic, test_cache_memory):
        test()
    print(banner)
    print("ALL CACHE TESTS PASSED")
    print(banner)