harshithsaiv committed on
Commit
e23db09
·
1 Parent(s): 6a962fc

feat: Testing pure pytorch

Browse files
Files changed (1) hide show
  1. kernel/quant_cache.py +61 -156
kernel/quant_cache.py CHANGED
@@ -1,171 +1,51 @@
1
  """
2
  Per-Head Mixed-Precision KV Cache
3
- ----------------------------------
4
- Quantizes each attention head's K and V tensors
5
- to either 4-bit or 8-bit based on calibrated sensitivity.
6
  """
7
 
8
  import torch
9
- import triton
10
- import triton.language as tl
11
-
12
- # ─── Triton Kernels ───────────────────────────────────────────────
13
-
14
- @triton.jit
15
- def quantize_8bit_kernel(
16
- x_ptr, q_ptr, scale_ptr, zp_ptr,
17
- N, BLOCK: tl.constexpr,
18
- ):
19
- pid = tl.program_id(0)
20
- offs = pid * BLOCK + tl.arange(0, BLOCK)
21
- mask = offs < N
22
-
23
- x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
24
-
25
- x_min = tl.min(x, axis=0)
26
- x_max = tl.max(x, axis=0)
27
- scale = (x_max - x_min) / 255.0
28
- scale = tl.where(scale < 1e-8, 1e-8, scale)
29
- zp = x_min
30
-
31
- # round by adding 0.5 then casting
32
- q = ((x - zp) / scale + 0.5).to(tl.int32)
33
- q = tl.where(q < 0, 0, q)
34
- q = tl.where(q > 255, 255, q)
35
-
36
- tl.store(q_ptr + offs, q.to(tl.int8), mask=mask)
37
- if pid == 0:
38
- tl.store(scale_ptr, scale)
39
- tl.store(zp_ptr, zp)
40
-
41
-
42
- @triton.jit
43
- def dequantize_8bit_kernel(
44
- q_ptr, scale_ptr, zp_ptr, out_ptr,
45
- N, BLOCK: tl.constexpr,
46
- ):
47
- pid = tl.program_id(0)
48
- offs = pid * BLOCK + tl.arange(0, BLOCK)
49
- mask = offs < N
50
-
51
- q = tl.load(q_ptr + offs, mask=mask, other=0).to(tl.float32)
52
- scale = tl.load(scale_ptr).to(tl.float32)
53
- zp = tl.load(zp_ptr).to(tl.float32)
54
-
55
- x = q * scale + zp
56
- tl.store(out_ptr + offs, x.to(tl.float16), mask=mask)
57
-
58
-
59
- @triton.jit
60
- def quantize_4bit_kernel(
61
- x_ptr, q_ptr, scale_ptr, zp_ptr,
62
- N, BLOCK: tl.constexpr,
63
- ):
64
- pid = tl.program_id(0)
65
- offs_out = pid * BLOCK + tl.arange(0, BLOCK)
66
- offs_in = offs_out * 2
67
- mask = offs_in + 1 < N
68
-
69
- x0 = tl.load(x_ptr + offs_in, mask=mask, other=0.0).to(tl.float32)
70
- x1 = tl.load(x_ptr + offs_in + 1, mask=mask, other=0.0).to(tl.float32)
71
-
72
- x_min = tl.minimum(tl.min(x0, axis=0), tl.min(x1, axis=0))
73
- x_max = tl.maximum(tl.max(x0, axis=0), tl.max(x1, axis=0))
74
- scale = (x_max - x_min) / 15.0
75
- scale = tl.where(scale < 1e-8, 1e-8, scale)
76
- zp = x_min
77
 
78
- q0 = ((x0 - zp) / scale + 0.5).to(tl.int32)
79
- q1 = ((x1 - zp) / scale + 0.5).to(tl.int32)
80
- q0 = tl.where(q0 < 0, 0, tl.where(q0 > 15, 15, q0)).to(tl.int8)
81
- q1 = tl.where(q1 < 0, 0, tl.where(q1 > 15, 15, q1)).to(tl.int8)
82
-
83
- packed = q0 | (q1 << 4)
84
- tl.store(q_ptr + offs_out, packed, mask=mask)
85
- if pid == 0:
86
- tl.store(scale_ptr, scale)
87
- tl.store(zp_ptr, zp)
88
-
89
-
90
- @triton.jit
91
- def dequantize_4bit_kernel(
92
- q_ptr, scale_ptr, zp_ptr, out_ptr,
93
- N, BLOCK: tl.constexpr,
94
- ):
95
- pid = tl.program_id(0)
96
- offs_out = pid * BLOCK + tl.arange(0, BLOCK)
97
- offs_in = offs_out * 2
98
- mask = offs_in + 1 < N
99
-
100
- packed = tl.load(q_ptr + offs_out, mask=mask, other=0).to(tl.int8)
101
- scale = tl.load(scale_ptr).to(tl.float32)
102
- zp = tl.load(zp_ptr).to(tl.float32)
103
-
104
- q0 = (packed & 0x0F).to(tl.float32)
105
- q1 = ((packed >> 4) & 0x0F).to(tl.float32)
106
-
107
- x0 = q0 * scale + zp
108
- x1 = q1 * scale + zp
109
-
110
- tl.store(out_ptr + offs_in, x0.to(tl.float16), mask=mask)
111
- tl.store(out_ptr + offs_in + 1, x1.to(tl.float16), mask=mask)
112
-
113
-
114
- # ─── Python Wrappers ──────────────────────────────────────────────
115
-
116
- BLOCK_SIZE = 1024
117
 
118
  def quantize_head(x: torch.Tensor, bits: int):
119
- x = x.contiguous().to(torch.float16)
120
- N = x.numel()
121
- scale = torch.zeros(1, dtype=torch.float32, device=x.device)
122
- zp = torch.zeros(1, dtype=torch.float32, device=x.device)
123
-
124
  if bits == 8:
125
- q = torch.empty(N, dtype=torch.int8, device=x.device)
126
- grid = (triton.cdiv(N, BLOCK_SIZE),)
127
- quantize_8bit_kernel[grid](
128
- x.view(-1), q, scale, zp, N, BLOCK=BLOCK_SIZE
129
- )
130
  elif bits == 4:
131
- assert N % 2 == 0
132
- q = torch.empty(N // 2, dtype=torch.int8, device=x.device)
133
- grid = (triton.cdiv(N // 2, BLOCK_SIZE),)
134
- quantize_4bit_kernel[grid](
135
- x.view(-1), q, scale, zp, N, BLOCK=BLOCK_SIZE
136
- )
137
  else:
138
  raise ValueError(f"Unsupported bits: {bits}")
139
-
 
 
 
 
140
  return q, scale, zp
141
 
142
 
143
  def dequantize_head(q, scale, zp, bits, original_shape):
144
- if bits == 8:
145
- N = q.numel()
146
- out = torch.empty(N, dtype=torch.float16, device=q.device)
147
- grid = (triton.cdiv(N, BLOCK_SIZE),)
148
- dequantize_8bit_kernel[grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)
149
- elif bits == 4:
150
- N = q.numel() * 2
151
- out = torch.empty(N, dtype=torch.float16, device=q.device)
152
- grid = (triton.cdiv(q.numel(), BLOCK_SIZE),)
153
- dequantize_4bit_kernel[grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)
154
- else:
155
- raise ValueError(f"Unsupported bits: {bits}")
156
 
157
- return out.view(original_shape)
158
-
159
-
160
- # ─── Cache Manager ────────────────────────────────────────────────
161
 
162
  class MixedPrecisionKVCache:
 
 
 
 
163
  def __init__(self, bit_alloc: list):
164
  self.bit_alloc = bit_alloc
165
  self.k_cache = []
166
  self.v_cache = []
167
 
168
  def store(self, k: torch.Tensor, v: torch.Tensor):
 
169
  self.k_cache = []
170
  self.v_cache = []
171
  for h in range(k.shape[1]):
@@ -178,6 +58,7 @@ class MixedPrecisionKVCache:
178
  self.v_cache.append((vq, vs, vz, v_head.shape, bits))
179
 
180
  def retrieve(self):
 
181
  ks = [dequantize_head(q,s,z,b,sh) for q,s,z,sh,b in self.k_cache]
182
  vs = [dequantize_head(q,s,z,b,sh) for q,s,z,sh,b in self.v_cache]
183
  k = torch.stack(ks, dim=0).unsqueeze(0)
@@ -185,10 +66,15 @@ class MixedPrecisionKVCache:
185
  return k, v
186
 
187
  def memory_bytes(self):
188
- return sum(q.numel() + 8 for q,s,z,sh,b in self.k_cache + self.v_cache)
189
-
 
 
 
 
 
 
190
 
191
- # ─── Test ─────────────────────────────────────────────────────────
192
 
193
  if __name__ == "__main__":
194
  print("Testing MixedPrecisionKVCache...")
@@ -196,24 +82,43 @@ if __name__ == "__main__":
196
  k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
197
  v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
198
 
199
- bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
200
- cache = MixedPrecisionKVCache(bit_alloc)
201
-
 
202
  cache.store(k, v)
203
  k_out, v_out = cache.retrieve()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
 
 
 
 
 
 
205
  k_err = (k - k_out).abs().mean().item()
206
  v_err = (v - v_out).abs().mean().item()
207
- print(f"K reconstruction error: {k_err:.6f}")
208
- print(f"V reconstruction error: {v_err:.6f}")
209
 
210
  fp16_bytes = k.numel() * 2 * 2
211
  quant_bytes = cache.memory_bytes()
212
  print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
213
  print(f"Quant memory: {quant_bytes/1024:.1f} KB")
214
  print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
215
-
216
- assert k_err < 0.1, f"K error too high: {k_err}"
217
- assert v_err < 0.1, f"V error too high: {v_err}"
218
- print("\n✅ All tests passed!")
219
- EOF
 
1
  """
2
  Per-Head Mixed-Precision KV Cache
3
+ Pure-PyTorch reference implementation, kept simple for correctness; Triton optimization is planned for later.
 
 
4
  """
5
 
6
  import torch
7
+ import json
8
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def quantize_head(x: torch.Tensor, bits: int):
12
+ """Quantize [seq, head_dim] tensor to given bits."""
13
+ x = x.float()
14
+ x_min = x.min()
15
+ x_max = x.max()
16
+
17
  if bits == 8:
18
+ qmax = 255.0
 
 
 
 
19
  elif bits == 4:
20
+ qmax = 15.0
 
 
 
 
 
21
  else:
22
  raise ValueError(f"Unsupported bits: {bits}")
23
+
24
+ scale = (x_max - x_min).clamp(min=1e-8) / qmax
25
+ zp = x_min
26
+
27
+ q = ((x - zp) / scale).round().clamp(0, qmax).to(torch.uint8)
28
  return q, scale, zp
29
 
30
 
31
  def dequantize_head(q, scale, zp, bits, original_shape):
32
+ """Dequantize back to float16."""
33
+ x = q.float() * scale + zp
34
+ return x.to(torch.float16).view(original_shape)
 
 
 
 
 
 
 
 
 
35
 
 
 
 
 
36
 
37
  class MixedPrecisionKVCache:
38
+ """
39
+ Stores quantized K and V for all heads in one layer.
40
+ bit_alloc: list of ints, one per head (4 or 8)
41
+ """
42
  def __init__(self, bit_alloc: list):
43
  self.bit_alloc = bit_alloc
44
  self.k_cache = []
45
  self.v_cache = []
46
 
47
  def store(self, k: torch.Tensor, v: torch.Tensor):
48
+ """k, v: [batch, num_heads, seq, head_dim]"""
49
  self.k_cache = []
50
  self.v_cache = []
51
  for h in range(k.shape[1]):
 
58
  self.v_cache.append((vq, vs, vz, v_head.shape, bits))
59
 
60
  def retrieve(self):
61
+ """Dequantize all heads, return [1, heads, seq, head_dim] float16."""
62
  ks = [dequantize_head(q,s,z,b,sh) for q,s,z,sh,b in self.k_cache]
63
  vs = [dequantize_head(q,s,z,b,sh) for q,s,z,sh,b in self.v_cache]
64
  k = torch.stack(ks, dim=0).unsqueeze(0)
 
66
  return k, v
67
 
68
  def memory_bytes(self):
69
+ total = 0
70
+ for (q, s, z, sh, bits) in self.k_cache + self.v_cache:
71
+ if bits == 4:
72
+ # 4-bit: 2 values per byte
73
+ total += q.numel() // 2 + 8
74
+ else:
75
+ total += q.numel() + 8
76
+ return total
77
 
 
78
 
79
  if __name__ == "__main__":
80
  print("Testing MixedPrecisionKVCache...")
 
82
  k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
83
  v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
84
 
85
+ # test 8-bit only first
86
+ print("\n--- 8-bit only ---")
87
+ bit_alloc = [8] * 8
88
+ cache = MixedPrecisionKVCache(bit_alloc)
89
  cache.store(k, v)
90
  k_out, v_out = cache.retrieve()
91
+ k_err = (k - k_out).abs().mean().item()
92
+ v_err = (v - v_out).abs().mean().item()
93
+ print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
94
+ assert k_err < 0.01, f"8-bit K error too high: {k_err}"
95
+ print("✅ 8-bit passed!")
96
+
97
+ # test 4-bit only
98
+ print("\n--- 4-bit only ---")
99
+ bit_alloc = [4] * 8
100
+ cache = MixedPrecisionKVCache(bit_alloc)
101
+ cache.store(k, v)
102
+ k_out, v_out = cache.retrieve()
103
+ k_err = (k - k_out).abs().mean().item()
104
+ v_err = (v - v_out).abs().mean().item()
105
+ print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
106
+ assert k_err < 0.1, f"4-bit K error too high: {k_err}"
107
+ print("✅ 4-bit passed!")
108
 
109
+ # test mixed
110
+ print("\n--- Mixed 4/8-bit ---")
111
+ bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
112
+ cache = MixedPrecisionKVCache(bit_alloc)
113
+ cache.store(k, v)
114
+ k_out, v_out = cache.retrieve()
115
  k_err = (k - k_out).abs().mean().item()
116
  v_err = (v - v_out).abs().mean().item()
117
+ print(f"K error: {k_err:.6f} V error: {v_err:.6f}")
 
118
 
119
  fp16_bytes = k.numel() * 2 * 2
120
  quant_bytes = cache.memory_bytes()
121
  print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
122
  print(f"Quant memory: {quant_bytes/1024:.1f} KB")
123
  print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
124
+ print("\n✅ All tests passed!")