harshithsaiv commited on
Commit
bc4bbbe
Β·
1 Parent(s): 91c163e

chore: libdevice not present in the current version

Browse files
Files changed (1) hide show
  1. kernel/quant_cache.py +45 -119
kernel/quant_cache.py CHANGED
@@ -1,51 +1,40 @@
 
1
  """
2
  Per-Head Mixed-Precision KV Cache
3
  ----------------------------------
4
  Quantizes each attention head's K and V tensors
5
  to either 4-bit or 8-bit based on calibrated sensitivity.
6
-
7
- Layout per head:
8
- - quantized data (int8 tensor, packed for 4-bit)
9
- - scale (float16 scalar)
10
- - zero_point (float16 scalar)
11
  """
12
 
13
  import torch
14
  import triton
15
  import triton.language as tl
16
- import json
17
- import os
18
 
19
  # ─── Triton Kernels ───────────────────────────────────────────────
20
 
21
  @triton.jit
22
  def quantize_8bit_kernel(
23
- x_ptr, # input [seq, head_dim]
24
- q_ptr, # output [seq, head_dim] int8
25
- scale_ptr, # output scalar float32
26
- zp_ptr, # output scalar float32
27
- N, # total elements = seq * head_dim
28
- BLOCK: tl.constexpr,
29
  ):
30
- pid = tl.program_id(0)
31
  offs = pid * BLOCK + tl.arange(0, BLOCK)
32
  mask = offs < N
33
 
34
  x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
35
 
36
- # compute scale and zero point from min/max
37
  x_min = tl.min(x, axis=0)
38
  x_max = tl.max(x, axis=0)
39
  scale = (x_max - x_min) / 255.0
40
- scale = tl.maximum(scale, 1e-8)
41
  zp = x_min
42
 
43
- # quantize
44
- q = tl.extra.libdevice.round((x - zp) / scale)
45
- q = tl.minimum(tl.maximum(q, 0.0), 255.0)
 
46
 
47
- tl.store(q_ptr + offs, q.to(tl.int8), mask=mask)
48
- # only first thread writes scale/zp
49
  if pid == 0:
50
  tl.store(scale_ptr, scale)
51
  tl.store(zp_ptr, zp)
@@ -53,18 +42,14 @@ def quantize_8bit_kernel(
53
 
54
  @triton.jit
55
  def dequantize_8bit_kernel(
56
- q_ptr, # input [seq, head_dim] int8
57
- scale_ptr, # input scalar
58
- zp_ptr, # input scalar
59
- out_ptr, # output [seq, head_dim] float16
60
- N,
61
- BLOCK: tl.constexpr,
62
  ):
63
  pid = tl.program_id(0)
64
  offs = pid * BLOCK + tl.arange(0, BLOCK)
65
  mask = offs < N
66
 
67
- q = tl.load(q_ptr + offs, mask=mask, other=0).to(tl.float32)
68
  scale = tl.load(scale_ptr).to(tl.float32)
69
  zp = tl.load(zp_ptr).to(tl.float32)
70
 
@@ -74,15 +59,10 @@ def dequantize_8bit_kernel(
74
 
75
  @triton.jit
76
  def quantize_4bit_kernel(
77
- x_ptr,
78
- q_ptr, # output [seq, head_dim] int8 (2 values packed per byte)
79
- scale_ptr,
80
- zp_ptr,
81
- N, # total elements (must be even)
82
- BLOCK: tl.constexpr,
83
  ):
84
- pid = tl.program_id(0)
85
- # each thread block handles BLOCK output bytes = BLOCK*2 input elements
86
  offs_out = pid * BLOCK + tl.arange(0, BLOCK)
87
  offs_in = offs_out * 2
88
  mask = offs_in + 1 < N
@@ -90,22 +70,19 @@ def quantize_4bit_kernel(
90
  x0 = tl.load(x_ptr + offs_in, mask=mask, other=0.0).to(tl.float32)
91
  x1 = tl.load(x_ptr + offs_in + 1, mask=mask, other=0.0).to(tl.float32)
92
 
93
- # share scale across both elements
94
  x_min = tl.minimum(tl.min(x0, axis=0), tl.min(x1, axis=0))
95
  x_max = tl.maximum(tl.max(x0, axis=0), tl.max(x1, axis=0))
96
  scale = (x_max - x_min) / 15.0
97
- scale = tl.maximum(scale, 1e-8)
98
  zp = x_min
99
 
100
- q0 = tl.extra.libdevice.round((x0 - zp) / scale)
101
- q1 = tl.extra.libdevice.round((x1 - zp) / scale)
102
- q0 = tl.minimum(tl.maximum(q0, 0.0), 15.0).to(tl.int8)
103
- q1 = tl.minimum(tl.maximum(q1, 0.0), 15.0).to(tl.int8)
104
 
105
- # pack two 4-bit values into one int8 byte
106
  packed = q0 | (q1 << 4)
107
  tl.store(q_ptr + offs_out, packed, mask=mask)
108
-
109
  if pid == 0:
110
  tl.store(scale_ptr, scale)
111
  tl.store(zp_ptr, zp)
@@ -113,12 +90,8 @@ def quantize_4bit_kernel(
113
 
114
  @triton.jit
115
  def dequantize_4bit_kernel(
116
- q_ptr,
117
- scale_ptr,
118
- zp_ptr,
119
- out_ptr,
120
- N,
121
- BLOCK: tl.constexpr,
122
  ):
123
  pid = tl.program_id(0)
124
  offs_out = pid * BLOCK + tl.arange(0, BLOCK)
@@ -129,7 +102,6 @@ def dequantize_4bit_kernel(
129
  scale = tl.load(scale_ptr).to(tl.float32)
130
  zp = tl.load(zp_ptr).to(tl.float32)
131
 
132
- # unpack
133
  q0 = (packed & 0x0F).to(tl.float32)
134
  q1 = ((packed >> 4) & 0x0F).to(tl.float32)
135
 
@@ -145,26 +117,20 @@ def dequantize_4bit_kernel(
145
  BLOCK_SIZE = 1024
146
 
147
  def quantize_head(x: torch.Tensor, bits: int):
148
- """
149
- Quantize a single head tensor using Triton kernel.
150
- x: [seq_len, head_dim] float16
151
- returns: (q, scale, zp)
152
- """
153
- x = x.contiguous()
154
  N = x.numel()
155
-
156
  scale = torch.zeros(1, dtype=torch.float32, device=x.device)
157
  zp = torch.zeros(1, dtype=torch.float32, device=x.device)
158
 
159
  if bits == 8:
160
- q = torch.empty(N, dtype=torch.int8, device=x.device)
161
  grid = (triton.cdiv(N, BLOCK_SIZE),)
162
  quantize_8bit_kernel[grid](
163
  x.view(-1), q, scale, zp, N, BLOCK=BLOCK_SIZE
164
  )
165
  elif bits == 4:
166
- assert N % 2 == 0, "head_dim must be even for 4-bit packing"
167
- q = torch.empty(N // 2, dtype=torch.int8, device=x.device)
168
  grid = (triton.cdiv(N // 2, BLOCK_SIZE),)
169
  quantize_4bit_kernel[grid](
170
  x.view(-1), q, scale, zp, N, BLOCK=BLOCK_SIZE
@@ -175,20 +141,14 @@ def quantize_head(x: torch.Tensor, bits: int):
175
  return q, scale, zp
176
 
177
 
178
- def dequantize_head(q: torch.Tensor, scale: torch.Tensor,
179
- zp: torch.Tensor, bits: int,
180
- original_shape: tuple) -> torch.Tensor:
181
- """
182
- Dequantize back to float16.
183
- Returns tensor of original_shape in float16.
184
- """
185
  if bits == 8:
186
- N = q.numel()
187
  out = torch.empty(N, dtype=torch.float16, device=q.device)
188
  grid = (triton.cdiv(N, BLOCK_SIZE),)
189
  dequantize_8bit_kernel[grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)
190
  elif bits == 4:
191
- N = q.numel() * 2
192
  out = torch.empty(N, dtype=torch.float16, device=q.device)
193
  grid = (triton.cdiv(q.numel(), BLOCK_SIZE),)
194
  dequantize_4bit_kernel[grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)
@@ -198,96 +158,62 @@ def dequantize_head(q: torch.Tensor, scale: torch.Tensor,
198
  return out.view(original_shape)
199
 
200
 
201
- # ─── Per-Layer Cache Manager ──────────────────────────────────────
202
 
203
  class MixedPrecisionKVCache:
204
- """
205
- Stores quantized K and V for all heads in one layer.
206
- bit_alloc: list of ints, one per head (4 or 8)
207
- """
208
-
209
  def __init__(self, bit_alloc: list):
210
- self.bit_alloc = bit_alloc # [num_heads]
211
- self.k_cache = [] # list of (q, scale, zp, shape)
212
  self.v_cache = []
213
 
214
  def store(self, k: torch.Tensor, v: torch.Tensor):
215
- """
216
- k, v: [batch, num_heads, seq, head_dim]
217
- Quantizes each head independently.
218
- """
219
  self.k_cache = []
220
  self.v_cache = []
221
- num_heads = k.shape[1]
222
-
223
- for h in range(num_heads):
224
  bits = self.bit_alloc[h]
225
- k_head = k[0, h] # [seq, head_dim]
226
  v_head = v[0, h]
227
-
228
  kq, ks, kz = quantize_head(k_head, bits)
229
  vq, vs, vz = quantize_head(v_head, bits)
230
-
231
  self.k_cache.append((kq, ks, kz, k_head.shape, bits))
232
  self.v_cache.append((vq, vs, vz, v_head.shape, bits))
233
 
234
- def retrieve(self) -> tuple:
235
- """
236
- Dequantize all heads and reconstruct full K, V tensors.
237
- Returns k, v: [1, num_heads, seq, head_dim] float16
238
- """
239
- ks, vs = [], []
240
- for (kq, ksc, kzp, ksh, kb) in self.k_cache:
241
- ks.append(dequantize_head(kq, ksc, kzp, kb, ksh))
242
- for (vq, vsc, vzp, vsh, vb) in self.v_cache:
243
- vs.append(dequantize_head(vq, vsc, vzp, vb, vsh))
244
-
245
- k = torch.stack(ks, dim=0).unsqueeze(0) # [1, heads, seq, head_dim]
246
- v = torch.stack(vs, dim=0).unsqueeze(0)
247
  return k, v
248
 
249
- def memory_bytes(self) -> int:
250
- """Estimate memory used by quantized cache."""
251
- total = 0
252
- for (q, s, z, shape, bits) in self.k_cache + self.v_cache:
253
- total += q.numel() + 2 * 4 # data + scale + zp
254
- return total
255
 
256
 
257
- # ─── Quick Correctness Test ───────────────────────────────────────
258
 
259
  if __name__ == "__main__":
260
  print("Testing MixedPrecisionKVCache...")
261
-
262
- # simulate one layer: batch=1, heads=8, seq=512, head_dim=128
263
  torch.manual_seed(42)
264
  k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
265
  v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
266
 
267
- # mixed allocation: alternating 4 and 8 bit
268
  bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
269
- cache = MixedPrecisionKVCache(bit_alloc)
270
 
271
- # store
272
  cache.store(k, v)
273
-
274
- # retrieve
275
  k_out, v_out = cache.retrieve()
276
 
277
- # correctness
278
  k_err = (k - k_out).abs().mean().item()
279
  v_err = (v - v_out).abs().mean().item()
280
  print(f"K reconstruction error: {k_err:.6f}")
281
  print(f"V reconstruction error: {v_err:.6f}")
282
 
283
- # memory savings
284
- fp16_bytes = k.numel() * 2 * 2 # k + v, 2 bytes each
285
  quant_bytes = cache.memory_bytes()
286
  print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
287
  print(f"Quant memory: {quant_bytes/1024:.1f} KB")
288
  print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
289
 
290
- # check errors are reasonable
291
  assert k_err < 0.1, f"K error too high: {k_err}"
292
  assert v_err < 0.1, f"V error too high: {v_err}"
293
  print("\nβœ… All tests passed!")
 
1
+ cat > ~/kv-hack/kernel/quant_cache.py << 'EOF'
2
  """
3
  Per-Head Mixed-Precision KV Cache
4
  ----------------------------------
5
  Quantizes each attention head's K and V tensors
6
  to either 4-bit or 8-bit based on calibrated sensitivity.
 
 
 
 
 
7
  """
8
 
9
  import torch
10
  import triton
11
  import triton.language as tl
 
 
12
 
13
  # ─── Triton Kernels ───────────────────────────────────────────────
14
 
15
  @triton.jit
16
  def quantize_8bit_kernel(
17
+ x_ptr, q_ptr, scale_ptr, zp_ptr,
18
+ N, BLOCK: tl.constexpr,
 
 
 
 
19
  ):
20
+ pid = tl.program_id(0)
21
  offs = pid * BLOCK + tl.arange(0, BLOCK)
22
  mask = offs < N
23
 
24
  x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
25
 
 
26
  x_min = tl.min(x, axis=0)
27
  x_max = tl.max(x, axis=0)
28
  scale = (x_max - x_min) / 255.0
29
+ scale = tl.where(scale < 1e-8, 1e-8, scale)
30
  zp = x_min
31
 
32
+ # round by adding 0.5 then casting
33
+ q = ((x - zp) / scale + 0.5).to(tl.int32)
34
+ q = tl.where(q < 0, 0, q)
35
+ q = tl.where(q > 255, 255, q)
36
 
37
+ tl.store(q_ptr + offs, q.to(tl.int8), mask=mask)
 
38
  if pid == 0:
39
  tl.store(scale_ptr, scale)
40
  tl.store(zp_ptr, zp)
 
42
 
43
  @triton.jit
44
  def dequantize_8bit_kernel(
45
+ q_ptr, scale_ptr, zp_ptr, out_ptr,
46
+ N, BLOCK: tl.constexpr,
 
 
 
 
47
  ):
48
  pid = tl.program_id(0)
49
  offs = pid * BLOCK + tl.arange(0, BLOCK)
50
  mask = offs < N
51
 
52
+ q = tl.load(q_ptr + offs, mask=mask, other=0).to(tl.float32)
53
  scale = tl.load(scale_ptr).to(tl.float32)
54
  zp = tl.load(zp_ptr).to(tl.float32)
55
 
 
59
 
60
  @triton.jit
61
  def quantize_4bit_kernel(
62
+ x_ptr, q_ptr, scale_ptr, zp_ptr,
63
+ N, BLOCK: tl.constexpr,
 
 
 
 
64
  ):
65
+ pid = tl.program_id(0)
 
66
  offs_out = pid * BLOCK + tl.arange(0, BLOCK)
67
  offs_in = offs_out * 2
68
  mask = offs_in + 1 < N
 
70
  x0 = tl.load(x_ptr + offs_in, mask=mask, other=0.0).to(tl.float32)
71
  x1 = tl.load(x_ptr + offs_in + 1, mask=mask, other=0.0).to(tl.float32)
72
 
 
73
  x_min = tl.minimum(tl.min(x0, axis=0), tl.min(x1, axis=0))
74
  x_max = tl.maximum(tl.max(x0, axis=0), tl.max(x1, axis=0))
75
  scale = (x_max - x_min) / 15.0
76
+ scale = tl.where(scale < 1e-8, 1e-8, scale)
77
  zp = x_min
78
 
79
+ q0 = ((x0 - zp) / scale + 0.5).to(tl.int32)
80
+ q1 = ((x1 - zp) / scale + 0.5).to(tl.int32)
81
+ q0 = tl.where(q0 < 0, 0, tl.where(q0 > 15, 15, q0)).to(tl.int8)
82
+ q1 = tl.where(q1 < 0, 0, tl.where(q1 > 15, 15, q1)).to(tl.int8)
83
 
 
84
  packed = q0 | (q1 << 4)
85
  tl.store(q_ptr + offs_out, packed, mask=mask)
 
86
  if pid == 0:
87
  tl.store(scale_ptr, scale)
88
  tl.store(zp_ptr, zp)
 
90
 
91
  @triton.jit
92
  def dequantize_4bit_kernel(
93
+ q_ptr, scale_ptr, zp_ptr, out_ptr,
94
+ N, BLOCK: tl.constexpr,
 
 
 
 
95
  ):
96
  pid = tl.program_id(0)
97
  offs_out = pid * BLOCK + tl.arange(0, BLOCK)
 
102
  scale = tl.load(scale_ptr).to(tl.float32)
103
  zp = tl.load(zp_ptr).to(tl.float32)
104
 
 
105
  q0 = (packed & 0x0F).to(tl.float32)
106
  q1 = ((packed >> 4) & 0x0F).to(tl.float32)
107
 
 
117
  BLOCK_SIZE = 1024
118
 
119
  def quantize_head(x: torch.Tensor, bits: int):
120
+ x = x.contiguous().to(torch.float16)
 
 
 
 
 
121
  N = x.numel()
 
122
  scale = torch.zeros(1, dtype=torch.float32, device=x.device)
123
  zp = torch.zeros(1, dtype=torch.float32, device=x.device)
124
 
125
  if bits == 8:
126
+ q = torch.empty(N, dtype=torch.int8, device=x.device)
127
  grid = (triton.cdiv(N, BLOCK_SIZE),)
128
  quantize_8bit_kernel[grid](
129
  x.view(-1), q, scale, zp, N, BLOCK=BLOCK_SIZE
130
  )
131
  elif bits == 4:
132
+ assert N % 2 == 0
133
+ q = torch.empty(N // 2, dtype=torch.int8, device=x.device)
134
  grid = (triton.cdiv(N // 2, BLOCK_SIZE),)
135
  quantize_4bit_kernel[grid](
136
  x.view(-1), q, scale, zp, N, BLOCK=BLOCK_SIZE
 
141
  return q, scale, zp
142
 
143
 
144
+ def dequantize_head(q, scale, zp, bits, original_shape):
 
 
 
 
 
 
145
  if bits == 8:
146
+ N = q.numel()
147
  out = torch.empty(N, dtype=torch.float16, device=q.device)
148
  grid = (triton.cdiv(N, BLOCK_SIZE),)
149
  dequantize_8bit_kernel[grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)
150
  elif bits == 4:
151
+ N = q.numel() * 2
152
  out = torch.empty(N, dtype=torch.float16, device=q.device)
153
  grid = (triton.cdiv(q.numel(), BLOCK_SIZE),)
154
  dequantize_4bit_kernel[grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)
 
158
  return out.view(original_shape)
159
 
160
 
161
+ # ─── Cache Manager ────────────────────────────────────────────────
162
 
163
  class MixedPrecisionKVCache:
 
 
 
 
 
164
  def __init__(self, bit_alloc: list):
165
+ self.bit_alloc = bit_alloc
166
+ self.k_cache = []
167
  self.v_cache = []
168
 
169
  def store(self, k: torch.Tensor, v: torch.Tensor):
 
 
 
 
170
  self.k_cache = []
171
  self.v_cache = []
172
+ for h in range(k.shape[1]):
 
 
173
  bits = self.bit_alloc[h]
174
+ k_head = k[0, h]
175
  v_head = v[0, h]
 
176
  kq, ks, kz = quantize_head(k_head, bits)
177
  vq, vs, vz = quantize_head(v_head, bits)
 
178
  self.k_cache.append((kq, ks, kz, k_head.shape, bits))
179
  self.v_cache.append((vq, vs, vz, v_head.shape, bits))
180
 
181
+ def retrieve(self):
182
+ ks = [dequantize_head(q,s,z,b,sh) for q,s,z,sh,b in self.k_cache]
183
+ vs = [dequantize_head(q,s,z,b,sh) for q,s,z,sh,b in self.v_cache]
184
+ k = torch.stack(ks, dim=0).unsqueeze(0)
185
+ v = torch.stack(vs, dim=0).unsqueeze(0)
 
 
 
 
 
 
 
 
186
  return k, v
187
 
188
+ def memory_bytes(self):
189
+ return sum(q.numel() + 8 for q,s,z,sh,b in self.k_cache + self.v_cache)
 
 
 
 
190
 
191
 
192
+ # ─── Test ─────────────────────────────────────────────────────────
193
 
194
  if __name__ == "__main__":
195
  print("Testing MixedPrecisionKVCache...")
 
 
196
  torch.manual_seed(42)
197
  k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
198
  v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
199
 
 
200
  bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
201
+ cache = MixedPrecisionKVCache(bit_alloc)
202
 
 
203
  cache.store(k, v)
 
 
204
  k_out, v_out = cache.retrieve()
205
 
 
206
  k_err = (k - k_out).abs().mean().item()
207
  v_err = (v - v_out).abs().mean().item()
208
  print(f"K reconstruction error: {k_err:.6f}")
209
  print(f"V reconstruction error: {v_err:.6f}")
210
 
211
+ fp16_bytes = k.numel() * 2 * 2
 
212
  quant_bytes = cache.memory_bytes()
213
  print(f"\nFP16 memory: {fp16_bytes/1024:.1f} KB")
214
  print(f"Quant memory: {quant_bytes/1024:.1f} KB")
215
  print(f"Compression: {fp16_bytes/quant_bytes:.2f}x")
216
 
 
217
  assert k_err < 0.1, f"K error too high: {k_err}"
218
  assert v_err < 0.1, f"V error too high: {v_err}"
219
  print("\nβœ… All tests passed!")