harshithsaiv committed on
Commit
91c163e
·
1 Parent(s): cbebdfd

feat: Initial Triton kernel

Browse files
Files changed (1) hide show
  1. kernel/quant_cache.py +294 -0
kernel/quant_cache.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Per-Head Mixed-Precision KV Cache
3
+ ----------------------------------
4
+ Quantizes each attention head's K and V tensors
5
+ to either 4-bit or 8-bit based on calibrated sensitivity.
6
+
7
+ Layout per head:
8
+ - quantized data (int8 tensor, packed for 4-bit)
9
+ - scale (float32 scalar)
10
+ - zero_point (float32 scalar)
11
+ """
12
+
13
+ import torch
14
+ import triton
15
+ import triton.language as tl
16
+ import json
17
+ import os
18
+
19
+ # ─── Triton Kernels ───────────────────────────────────────────────
20
+
21
@triton.jit
def quantize_8bit_kernel(
    x_ptr,      # input, N contiguous elements
    q_ptr,      # output, N int8 codes
    scale_ptr,  # output scalar: quantization step
    zp_ptr,     # output scalar: float offset added back on dequantization
    N,          # total elements = seq * head_dim
    BLOCK: tl.constexpr,
):
    """Quantize N values to signed 8-bit codes with one shared scale/offset.

    Each program quantizes one BLOCK-sized slice of the input, but the
    scale and offset are derived from the min/max of the WHOLE tensor so
    that every slice is encoded with the same parameters that end up in
    ``scale_ptr`` / ``zp_ptr``.

    Encoding: ``q = round((x - x_min) / scale) - 128`` clamped to
    [-128, 127], and the stored offset is ``x_min + 128 * scale``.
    Reconstruction is therefore ``q * scale + zp``.
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    # Global min/max. Masked-out lanes are replaced by +/-inf so padding
    # never influences the range.
    # NOTE: every program rescans the full input, which is redundant work;
    # a separate reduction pass would be cheaper for very long sequences.
    run_min = tl.full([BLOCK], float("inf"), dtype=tl.float32)
    run_max = tl.full([BLOCK], float("-inf"), dtype=tl.float32)
    for start in range(0, N, BLOCK):
        idx = start + lane
        valid = idx < N
        v = tl.load(x_ptr + idx, mask=valid, other=0.0).to(tl.float32)
        run_min = tl.minimum(run_min, tl.where(valid, v, float("inf")))
        run_max = tl.maximum(run_max, tl.where(valid, v, float("-inf")))
    x_min = tl.min(run_min, axis=0)
    x_max = tl.max(run_max, axis=0)

    # 256 levels across the observed range; the floor avoids divide-by-zero
    # when all values are equal.
    scale = tl.maximum((x_max - x_min) / 255.0, 1e-8)
    # Shift by 128 so the codes fit a signed int8 without overflow.
    zp = x_min + 128.0 * scale

    # Quantize this program's slice.
    offs = pid * BLOCK + lane
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    q = tl.extra.libdevice.round((x - x_min) / scale) - 128.0
    q = tl.minimum(tl.maximum(q, -128.0), 127.0)
    tl.store(q_ptr + offs, q.to(tl.int8), mask=mask)

    # All programs computed identical scale/zp; one write is enough.
    if pid == 0:
        tl.store(scale_ptr, scale)
        tl.store(zp_ptr, zp)
52
+
53
+
54
@triton.jit
def dequantize_8bit_kernel(
    q_ptr,      # input codes, N elements
    scale_ptr,  # input scalar
    zp_ptr,     # input scalar
    out_ptr,    # output, N elements written as float16
    N,
    BLOCK: tl.constexpr,
):
    """Reconstruct values from 8-bit codes as ``code * scale + zp``.

    Each program handles one BLOCK-sized slice; lanes at or beyond N are
    masked out of both the load and the store.
    """
    idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = idx < N

    # Shared parameters are read once per program and promoted to float32.
    step = tl.load(scale_ptr).to(tl.float32)
    offset = tl.load(zp_ptr).to(tl.float32)

    codes = tl.load(q_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    restored = codes * step + offset

    tl.store(out_ptr + idx, restored.to(tl.float16), mask=in_bounds)
73
+
74
+
75
@triton.jit
def quantize_4bit_kernel(
    x_ptr,      # input, N contiguous elements
    q_ptr,      # output, N // 2 bytes (two 4-bit codes per byte)
    scale_ptr,  # output scalar: quantization step
    zp_ptr,     # output scalar: minimum of the input (float offset)
    N,          # total elements (must be even)
    BLOCK: tl.constexpr,
):
    """Quantize N values to 4-bit codes, packed two per byte.

    Each program produces BLOCK output bytes from 2 * BLOCK input
    elements. The scale and offset are derived from the min/max of the
    WHOLE tensor so every program encodes with the same parameters that
    are written to ``scale_ptr`` / ``zp_ptr``.

    Packing: element ``2 * i`` goes to the low nibble of byte ``i`` and
    element ``2 * i + 1`` to the high nibble.
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    # Global min/max. Masked-out lanes are replaced by +/-inf so padding
    # never influences the range.
    # NOTE: every program rescans the full input, which is redundant work;
    # a separate reduction pass would be cheaper for very long sequences.
    run_min = tl.full([BLOCK], float("inf"), dtype=tl.float32)
    run_max = tl.full([BLOCK], float("-inf"), dtype=tl.float32)
    for start in range(0, N, BLOCK):
        idx = start + lane
        valid = idx < N
        v = tl.load(x_ptr + idx, mask=valid, other=0.0).to(tl.float32)
        run_min = tl.minimum(run_min, tl.where(valid, v, float("inf")))
        run_max = tl.maximum(run_max, tl.where(valid, v, float("-inf")))
    x_min = tl.min(run_min, axis=0)
    x_max = tl.max(run_max, axis=0)

    # 16 levels across the observed range; the floor avoids divide-by-zero
    # when all values are equal.
    scale = tl.maximum((x_max - x_min) / 15.0, 1e-8)
    zp = x_min

    # Each output byte covers one even/odd input pair.
    offs_out = pid * BLOCK + lane
    offs_in = offs_out * 2
    mask = offs_in + 1 < N

    x0 = tl.load(x_ptr + offs_in, mask=mask, other=0.0).to(tl.float32)
    x1 = tl.load(x_ptr + offs_in + 1, mask=mask, other=0.0).to(tl.float32)

    q0 = tl.extra.libdevice.round((x0 - zp) / scale)
    q1 = tl.extra.libdevice.round((x1 - zp) / scale)
    q0 = tl.minimum(tl.maximum(q0, 0.0), 15.0).to(tl.int32)
    q1 = tl.minimum(tl.maximum(q1, 0.0), 15.0).to(tl.int32)

    # Pack in int32 so the shift cannot overflow, then truncate to a byte.
    packed = (q0 | (q1 << 4)).to(tl.int8)
    tl.store(q_ptr + offs_out, packed, mask=mask)

    # All programs computed identical scale/zp; one write is enough.
    if pid == 0:
        tl.store(scale_ptr, scale)
        tl.store(zp_ptr, zp)
112
+
113
+
114
@triton.jit
def dequantize_4bit_kernel(
    q_ptr,      # input, packed bytes (two 4-bit codes each)
    scale_ptr,  # input scalar
    zp_ptr,     # input scalar
    out_ptr,    # output, N elements written as float16
    N,          # number of output elements
    BLOCK: tl.constexpr,
):
    """Unpack 4-bit codes and reconstruct values as ``code * scale + zp``.

    Each program reads BLOCK packed bytes and writes 2 * BLOCK outputs:
    the low nibble of byte ``i`` goes to element ``2 * i`` and the high
    nibble to element ``2 * i + 1``.
    """
    byte_idx = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    even_idx = byte_idx * 2
    odd_idx = even_idx + 1
    # A byte is processed only when both of its output slots are in range.
    in_bounds = odd_idx < N

    step = tl.load(scale_ptr).to(tl.float32)
    offset = tl.load(zp_ptr).to(tl.float32)
    packed = tl.load(q_ptr + byte_idx, mask=in_bounds, other=0).to(tl.int8)

    # Masking with 0x0F after the shift discards any sign-extended bits.
    low = (packed & 0x0F).to(tl.float32)
    high = ((packed >> 4) & 0x0F).to(tl.float32)

    tl.store(out_ptr + even_idx, (low * step + offset).to(tl.float16), mask=in_bounds)
    tl.store(out_ptr + odd_idx, (high * step + offset).to(tl.float16), mask=in_bounds)
141
+
142
+
143
+ # ─── Python Wrappers ──────────────────────────────────────────────
144
+
145
# Number of elements (8-bit) or packed bytes (4-bit) handled per program.
BLOCK_SIZE = 1024


def quantize_head(x: torch.Tensor, bits: int):
    """
    Quantize one head's tensor with the matching Triton kernel.

    x: tensor to quantize; it is made contiguous and flattened first.
    bits: 8 (one code per byte) or 4 (two codes packed per byte).
    returns: (q, scale, zp) where q is a flat int8 tensor and scale / zp
             are one-element float32 tensors on x's device.
    raises: ValueError for any other bit width.
    """
    if bits not in (4, 8):
        raise ValueError(f"Unsupported bits: {bits}")

    flat = x.contiguous().view(-1)
    N = flat.numel()
    device = flat.device

    if bits == 8:
        kernel = quantize_8bit_kernel
        stored = N
    else:
        assert N % 2 == 0, "head_dim must be even for 4-bit packing"
        kernel = quantize_4bit_kernel
        stored = N // 2

    scale = torch.zeros(1, dtype=torch.float32, device=device)
    zp = torch.zeros(1, dtype=torch.float32, device=device)
    q = torch.empty(stored, dtype=torch.int8, device=device)

    # One program per BLOCK_SIZE stored bytes.
    launch_grid = (triton.cdiv(stored, BLOCK_SIZE),)
    kernel[launch_grid](flat, q, scale, zp, N, BLOCK=BLOCK_SIZE)

    return q, scale, zp
176
+
177
+
178
def dequantize_head(q: torch.Tensor, scale: torch.Tensor,
                    zp: torch.Tensor, bits: int,
                    original_shape: tuple) -> torch.Tensor:
    """
    Reconstruct a float16 tensor from quantized codes.

    q: flat code tensor (one code per element for 8-bit, two per element
       for 4-bit).
    scale, zp: the parameters returned alongside q by quantization.
    bits: 8 or 4; anything else raises ValueError.
    original_shape: shape the flat result is viewed as before returning.
    """
    if bits not in (4, 8):
        raise ValueError(f"Unsupported bits: {bits}")

    stored = q.numel()
    if bits == 8:
        kernel, N = dequantize_8bit_kernel, stored
    else:
        # Every stored byte expands into two output elements.
        kernel, N = dequantize_4bit_kernel, stored * 2

    out = torch.empty(N, dtype=torch.float16, device=q.device)
    launch_grid = (triton.cdiv(stored, BLOCK_SIZE),)
    kernel[launch_grid](q, scale, zp, out, N, BLOCK=BLOCK_SIZE)

    return out.view(original_shape)
199
+
200
+
201
+ # ─── Per-Layer Cache Manager ──────────────────────────────────────
202
+
203
class MixedPrecisionKVCache:
    """
    Stores quantized K and V for all heads in one layer.

    bit_alloc: list of ints, one per head (4 or 8).

    Each cache entry is a tuple ``(q, scale, zp, shape, bits)`` holding the
    quantized codes, the two quantization parameters, the shape of the
    slice that was quantized, and the bit width used.
    """

    def __init__(self, bit_alloc: list):
        self.bit_alloc = bit_alloc   # [num_heads]
        self.k_cache = []            # list of (q, scale, zp, shape, bits)
        self.v_cache = []

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """
        Quantize and cache K and V, replacing any previous contents.

        k, v: [batch, num_heads, seq, head_dim]

        Each head is quantized independently, with one scale/offset shared
        by every batch entry of that head. Inputs are validated before the
        existing cache is touched, and the cache is only replaced once
        every head has been quantized.

        Raises:
            ValueError: if k or v is not 4-D, if their head counts differ,
                or if bit_alloc has fewer entries than there are heads.
        """
        if k.dim() != 4 or v.dim() != 4:
            raise ValueError(
                "k and v must be 4-D [batch, num_heads, seq, head_dim], "
                f"got {tuple(k.shape)} and {tuple(v.shape)}"
            )
        num_heads = k.shape[1]
        if v.shape[1] != num_heads:
            raise ValueError(
                f"k has {num_heads} heads but v has {v.shape[1]}"
            )
        if len(self.bit_alloc) < num_heads:
            raise ValueError(
                f"bit_alloc has {len(self.bit_alloc)} entries "
                f"but the input has {num_heads} heads"
            )

        new_k, new_v = [], []
        for h in range(num_heads):
            bits = self.bit_alloc[h]
            # Keep the batch dimension so no batch entry is dropped.
            k_head = k[:, h]   # [batch, seq, head_dim]
            v_head = v[:, h]

            kq, ks, kz = quantize_head(k_head, bits)
            vq, vs, vz = quantize_head(v_head, bits)

            new_k.append((kq, ks, kz, k_head.shape, bits))
            new_v.append((vq, vs, vz, v_head.shape, bits))

        self.k_cache = new_k
        self.v_cache = new_v

    def retrieve(self) -> tuple:
        """
        Dequantize all heads and reconstruct full K, V tensors.

        Returns k, v: [batch, num_heads, seq, head_dim] float16

        Raises:
            RuntimeError: if nothing has been stored yet.
        """
        if not self.k_cache or not self.v_cache:
            raise RuntimeError("KV cache is empty; call store() first")

        ks = [
            dequantize_head(q, sc, zp, bits, shape)
            for (q, sc, zp, shape, bits) in self.k_cache
        ]
        vs = [
            dequantize_head(q, sc, zp, bits, shape)
            for (q, sc, zp, shape, bits) in self.v_cache
        ]

        # Each entry is [batch, seq, head_dim]; heads become dimension 1.
        k = torch.stack(ks, dim=1)
        v = torch.stack(vs, dim=1)
        return k, v

    def memory_bytes(self) -> int:
        """Return the bytes held by quantized codes, scales and offsets."""
        total = 0
        for (q, s, z, _shape, _bits) in self.k_cache + self.v_cache:
            # Use the tensors' own element sizes instead of assumed widths.
            total += q.numel() * q.element_size()
            total += s.numel() * s.element_size()
            total += z.numel() * z.element_size()
        return total
255
+
256
+
257
+ # ─── Quick Correctness Test ───────────────────────────────────────
258
+
259
if __name__ == "__main__":
    # The Triton kernels only run on a GPU; fail early with a clear message.
    if not torch.cuda.is_available():
        raise SystemExit("A CUDA device is required to run this self-test.")

    print("Testing MixedPrecisionKVCache...")

    # simulate one layer: batch=1, heads=8, seq=512, head_dim=128
    torch.manual_seed(42)
    k = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 512, 128, dtype=torch.float16, device="cuda")

    # mixed allocation: alternating 4 and 8 bit
    bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
    cache = MixedPrecisionKVCache(bit_alloc)

    # round trip: quantize, then dequantize
    cache.store(k, v)
    k_out, v_out = cache.retrieve()

    # correctness: mean absolute reconstruction error
    k_err = (k - k_out).abs().mean().item()
    v_err = (v - v_out).abs().mean().item()
    print(f"K reconstruction error: {k_err:.6f}")
    print(f"V reconstruction error: {v_err:.6f}")

    # memory savings
    fp16_bytes = k.numel() * 2 * 2   # k + v, 2 bytes each
    quant_bytes = cache.memory_bytes()
    print(f"\nFP16 memory:  {fp16_bytes/1024:.1f} KB")
    print(f"Quant memory: {quant_bytes/1024:.1f} KB")
    print(f"Compression:  {fp16_bytes/quant_bytes:.2f}x")

    # check errors are reasonable
    assert k_err < 0.1, f"K error too high: {k_err}"
    assert v_err < 0.1, f"V error too high: {v_err}"
    print("\n✅ All tests passed!")