"""
Per-Head Mixed-Precision KV Cache
Using PyTorch for correctness, Triton optimization later.
"""

import torch


def quantize_head(x: torch.Tensor, bits: int):
    """Quantize [seq, head_dim] tensor to given bits."""
    x = x.float()
    x_min = x.min()
    x_max = x.max()
    
    if bits == 8:
        qmax = 255.0
    elif bits == 4:
        qmax = 15.0
    else:
        raise ValueError(f"Unsupported bits: {bits}")
    
    scale = (x_max - x_min).clamp(min=1e-8) / qmax
    zp    = x_min
    
    q = ((x - zp) / scale).round().clamp(0, qmax).to(torch.uint8)
    return q, scale, zp


def dequantize_head(q, scale, zp, bits, original_shape):
    """Dequantize back to float16."""
    x = q.float() * scale + zp
    return x.to(torch.float16).view(original_shape)
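
# Round-trip sanity sketch for the helpers above (the full cache is exercised
# under __main__ below); with qmax quantization levels over a value range R,
# the mean absolute error lands around R / (4 * qmax):
#
#   x = torch.randn(512, 128)
#   q, s, z = quantize_head(x, 8)
#   x_hat = dequantize_head(q, s, z, 8, x.shape)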


class MixedPrecisionKVCache:
    """
    Stores quantized K and V for all heads in one layer.
    bit_alloc: list of ints, one per head (4 or 8)
    """
    def __init__(self, bit_alloc: list):
        self.bit_alloc = bit_alloc
        self.k_cache   = []
        self.v_cache   = []

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """k, v: [batch, num_heads, seq, head_dim]; only batch size 1 is supported."""
        assert k.shape[0] == 1 and v.shape[0] == 1, "only batch size 1 is supported"
        self.k_cache = []
        self.v_cache = []
        for h in range(k.shape[1]):
            bits   = self.bit_alloc[h]
            k_head = k[0, h]
            v_head = v[0, h]
            kq, ks, kz = quantize_head(k_head, bits)
            vq, vs, vz = quantize_head(v_head, bits)
            self.k_cache.append((kq, ks, kz, k_head.shape, bits))
            self.v_cache.append((vq, vs, vz, v_head.shape, bits))

    def retrieve(self):
        """Dequantize all heads, return [1, heads, seq, head_dim] float16."""
        ks = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.k_cache]
        vs = [dequantize_head(q, s, z, b, sh) for q, s, z, sh, b in self.v_cache]
        k  = torch.stack(ks, dim=0).unsqueeze(0)
        v  = torch.stack(vs, dim=0).unsqueeze(0)
        return k, v

    def memory_bytes(self):
        """Theoretical memory: 4-bit counted as if packed two codes per byte
        (actual storage is one uint8 per code; see real_gpu_bytes)."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            overhead = 8  # per-head fp32 scale + zero-point
            if bits == 4:
                total += q.numel() // 2 + overhead  # theoretical packed size
            else:
                total += q.numel() + overhead
        return total

    def real_gpu_bytes(self):
        """Actual GPU memory used by the tensors."""
        total = 0
        for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
            total += q.numel() + 8  # one uint8 per code (4-bit pays a full byte) + fp32 scale/zero-point
        return total
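

# Illustrative only: what true 4-bit packing could look like. memory_bytes()
# reports the packed size the cache *would* reach if every byte held two
# 4-bit codes; pack_4bit / unpack_4bit below are hypothetical helpers showing
# one way to get there, and are not used by MixedPrecisionKVCache.

def pack_4bit(q: torch.Tensor) -> torch.Tensor:
    """Pack flat uint8 codes in [0, 15] two per byte (high nibble first)."""
    flat = q.flatten()
    if flat.numel() % 2:  # pad to an even count so pairs line up
        flat = torch.cat([flat, flat.new_zeros(1)])
    return (flat[0::2] << 4) | flat[1::2]


def unpack_4bit(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Invert pack_4bit, returning the first `numel` codes."""
    hi = (packed >> 4) & 0xF
    lo = packed & 0xF
    return torch.stack([hi, lo], dim=1).flatten()[:numel]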


if __name__ == "__main__":
    print("Testing MixedPrecisionKVCache...")
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back off-GPU
    k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)
    v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device=device)

    # test 8-bit only first
    print("\n--- 8-bit only ---")
    bit_alloc = [8] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    assert k_err < 0.01, f"8-bit K error too high: {k_err}"
    print("✅ 8-bit passed!")

    # test 4-bit only
    print("\n--- 4-bit only ---")
    bit_alloc = [4] * 8
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")
    assert k_err < 0.2, f"4-bit K error too high: {k_err}"  # ~0.14 on random data; real KV activations quantize more tightly
    print("✅ 4-bit passed!")

    # test mixed
    print("\n--- Mixed 4/8-bit ---")
    bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
    cache = MixedPrecisionKVCache(bit_alloc)
    cache.store(k, v)
    k_out, v_out = cache.retrieve()
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K error: {k_err:.6f}  V error: {v_err:.6f}")

    fp16_bytes  = k.numel() * 2 * 2  # 2 bytes per element, for both K and V
    quant_bytes = cache.memory_bytes()
    print(f"\nFP16 memory:  {fp16_bytes/1024:.1f} KB")
    print(f"Quant memory: {quant_bytes/1024:.1f} KB")
    print(f"Compression:  {fp16_bytes/quant_bytes:.2f}x")
    print("\n✅ All tests passed!")