theapemachine committed on
Commit
ae901d9
·
verified ·
1 Parent(s): be59ae6

Upload triton_sparse.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. triton_sparse.py +542 -6
triton_sparse.py CHANGED
@@ -5,11 +5,547 @@ Triton-fused Chunked Sparse Backward Pass.
5
  Replaces the Python for-loop over active chunks with fused Triton kernels:
6
  1. sparse_bwd_dW: grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X for active c
7
  2. sparse_bwd_dX: grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :] for active c
8
- 3. sparse_bwd_dbias: bias_grad[c*CS:(c+1)*CS] = dY[:, c*CS:(c+1)*CS].sum(dim=0)
9
 
10
- Includes Python-loop baseline, correctness tests, and isolated matmul microbenchmark.
11
-
12
- Usage:
13
- python triton_sparse.py # runs correctness + benchmark
14
  """
15
- # See repo file for full content - uploading from sandbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  Replaces the Python for-loop over active chunks with fused Triton kernels:
6
  1. sparse_bwd_dW: grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X for active c
7
  2. sparse_bwd_dX: grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :] for active c
8
+ 3. sparse_bwd_dbias: bias_grad[c*CS:(c+1)*CS] = grad_Y[:, c*CS:(c+1)*CS].sum(dim=0) for active c
9
 
10
+ Benchmark against the Python-loop baseline at various d_model sizes.
 
 
 
11
  """
12
+
13
+ import math
14
+ import os
15
+ import random
16
+ import time
17
+ import urllib.request
18
+ from dataclasses import dataclass
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ import triton
25
+ import triton.language as tl
26
+
27
+ try:
28
+ import tiktoken
29
+ except ImportError:
30
+ raise ImportError("pip install tiktoken")
31
+
32
+ # ═══════════════════════════════════════════════════════════════════
33
+ # TRITON KERNELS
34
+ # ═══════════════════════════════════════════════════════════════════
35
+
36
+ # ── Kernel 1: Sparse dW ──────────────────────────────────────────
37
+ # For each active chunk c:
38
+ # grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X
39
+ #
40
+ # In terms of shapes:
41
+ # grad_Y: (M, d_out), X: (M, d_in), W: (d_out, d_in)
42
+ # For chunk c: rows c*CS..(c+1)*CS of W get grad from cols c*CS..(c+1)*CS of grad_Y
43
+ #
44
+ # Grid: (num_active * ceil(CS/BN), ceil(d_in/BK))
45
+ # pid0 encodes (active_chunk_linear_id, N-block within CS)
46
+ # pid1 encodes K-block within d_in
47
+
48
@triton.autotune(
    configs=[
        triton.Config({'BN': 32, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 128, 'BM': 32}, num_stages=3, num_warps=4),
        triton.Config({'BN': 32, 'BK': 128, 'BM': 64}, num_stages=3, num_warps=4),
        triton.Config({'BN': 64, 'BK': 64, 'BM': 64}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dW_kernel(
    X_ptr, dY_ptr, dW_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_xm, stride_xk,
    stride_dym, stride_dyn,
    stride_dwn, stride_dwk,
    CS: tl.constexpr,
    BN: tl.constexpr,
    BK: tl.constexpr,
    BM: tl.constexpr,
):
    """Compute dW tiles for active chunks. Each program writes one [BN, BK] tile.

    Grid axis 0 enumerates (active chunk, BN-sized row block inside the chunk);
    grid axis 1 enumerates BK-sized column blocks over d_in.  For the selected
    chunk ``c`` the program accumulates ``dY[:, rows].T @ X[:, cols]`` over M
    in BM-sized steps (fp32 accumulator) and stores the tile into dW.

    Rows are masked against both the chunk size and ``d_out`` so that a
    trailing chunk that extends past ``d_out`` never touches memory outside
    dY / dW.
    """
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)

    # Decode axis 0 into (which active chunk, which row block inside it).
    N_BLOCKS_PER_CHUNK = tl.cdiv(CS, BN)
    chunk_linear_id = pid0 // N_BLOCKS_PER_CHUNK
    n_block_id = pid0 % N_BLOCKS_PER_CHUNK
    k_block_id = pid1

    if chunk_linear_id >= num_active:
        return

    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear_id)
    chunk_start = chunk_idx * CS

    # Tile ranges
    rn = n_block_id * BN + tl.arange(0, BN)  # rows of dW inside the chunk
    rk = k_block_id * BK + tl.arange(0, BK)  # cols of dW (= cols of X)

    n_abs = chunk_start + rn  # absolute dW rows / dY columns
    # Stay inside the chunk AND inside the tensor (partial last chunk).
    n_mask = (rn < CS) & (n_abs < d_out)
    k_mask = rk < d_in

    # Accumulate dY[:, chunk_cols].T @ X[:, k_cols] over M-tiles
    acc = tl.zeros((BN, BK), dtype=tl.float32)

    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = rm < M

        # X tile: (BM, BK)
        x = tl.load(
            X_ptr + rm[:, None] * stride_xm + rk[None, :] * stride_xk,
            mask=m_mask[:, None] & k_mask[None, :],
            other=0.0,
        )

        # dY tile: (BM, BN)
        dy = tl.load(
            dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
            mask=m_mask[:, None] & n_mask[None, :],
            other=0.0,
        )

        # dY.T @ X -> (BN, BK)
        acc = tl.dot(tl.trans(dy), x, acc=acc)

    # dW layout: (d_out, d_in); row = chunk_start + rn, col = rk
    dw_ptrs = dW_ptr + n_abs[:, None] * stride_dwn + rk[None, :] * stride_dwk
    tl.store(dw_ptrs, acc.to(dW_ptr.dtype.element_ty), mask=n_mask[:, None] & k_mask[None, :])
121
+
122
+
123
def sparse_bwd_dW(X, dY, active_chunks, chunk_size, d_out):
    """Launch the fused dW kernel and return a (d_out, d_in) gradient.

    The result is zero-initialised with X's device and dtype; only the row
    ranges belonging to the chunk indices in ``active_chunks`` are written by
    the kernel.  With no active chunks the zero tensor is returned directly
    and no kernel is launched.
    """
    n_rows, n_feat = X.shape
    n_active = active_chunks.shape[0]

    grad_weight = torch.zeros(d_out, n_feat, device=X.device, dtype=X.dtype)
    if not n_active:
        return grad_weight

    # The kernel reads chunk ids as a dense int32 vector.
    ids_i32 = active_chunks.to(torch.int32).contiguous()

    def launch_grid(meta):
        # axis 0: (active chunk, row block in chunk); axis 1: column blocks.
        return (
            n_active * triton.cdiv(chunk_size, meta['BN']),
            triton.cdiv(n_feat, meta['BK']),
        )

    strides = (
        X.stride(0), X.stride(1),
        dY.stride(0), dY.stride(1),
        grad_weight.stride(0), grad_weight.stride(1),
    )
    _sparse_bwd_dW_kernel[launch_grid](
        X, dY, grad_weight, ids_i32,
        n_rows, n_feat, d_out, n_active,
        *strides,
        CS=chunk_size,
    )
    return grad_weight
149
+
150
+
151
+ # ── Kernel 2: Sparse dX ──────────────────────────────────────────
152
+ # For each active chunk c:
153
+ # grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :]
154
+ #
155
+ # Grid: (ceil(M/BM), ceil(d_in/BK))
156
+ # Each program accumulates contributions from ALL active chunks.
157
+
158
@triton.autotune(
    configs=[
        triton.Config({'BM': 32, 'BK': 64, 'BN': 32}, num_stages=3, num_warps=4),
        triton.Config({'BM': 64, 'BK': 64, 'BN': 32}, num_stages=3, num_warps=4),
        triton.Config({'BM': 64, 'BK': 128, 'BN': 64}, num_stages=3, num_warps=4),
        triton.Config({'BM': 32, 'BK': 128, 'BN': 32}, num_stages=4, num_warps=4),
    ],
    key=['M', 'd_in', 'CS'],
)
@triton.jit
def _sparse_bwd_dX_kernel(
    dY_ptr, W_ptr, dX_ptr, chunk_ids_ptr,
    M, d_in, d_out, num_active,
    stride_dym, stride_dyn,
    stride_wn, stride_wk,
    stride_dxm, stride_dxk,
    CS: tl.constexpr,
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
):
    """Compute dX tiles by summing over active chunks.

    Each program owns one [BM, BK] tile of dX.  It walks every active chunk,
    and inside each chunk every BN-sized slice, accumulating
    ``dY[rows, slice] @ W[slice, cols]`` in fp32 before a single store.

    Chunk-slice indices are masked against both the chunk size and ``d_out``
    so a trailing chunk that runs past ``d_out`` contributes nothing instead
    of reading outside dY / W.
    """
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    rm = pid_m * BM + tl.arange(0, BM)
    rk = pid_k * BK + tl.arange(0, BK)
    m_mask = rm < M
    k_mask = rk < d_in

    acc = tl.zeros((BM, BK), dtype=tl.float32)

    # Sum over all active chunks
    for i in range(num_active):
        chunk_idx = tl.load(chunk_ids_ptr + i)
        chunk_start = chunk_idx * CS

        # Tile over BN within the chunk
        for n_start in range(0, CS, BN):
            rn = n_start + tl.arange(0, BN)
            n_abs = chunk_start + rn
            # Stay inside the chunk AND inside the tensor (partial last chunk).
            n_mask = (rn < CS) & (n_abs < d_out)

            # dY tile: (BM, BN)
            dy = tl.load(
                dY_ptr + rm[:, None] * stride_dym + n_abs[None, :] * stride_dyn,
                mask=m_mask[:, None] & n_mask[None, :],
                other=0.0,
            )

            # W tile: (BN, BK) -- W[chunk_start + rn, rk]
            w = tl.load(
                W_ptr + n_abs[:, None] * stride_wn + rk[None, :] * stride_wk,
                mask=n_mask[:, None] & k_mask[None, :],
                other=0.0,
            )

            # dY @ W -> (BM, BK)
            acc = tl.dot(dy, w, acc=acc)

    # Write dX
    dx_ptrs = dX_ptr + rm[:, None] * stride_dxm + rk[None, :] * stride_dxk
    tl.store(dx_ptrs, acc.to(dX_ptr.dtype.element_ty), mask=m_mask[:, None] & k_mask[None, :])
221
+
222
+
223
def sparse_bwd_dX(dY, W, active_chunks, chunk_size, M, d_in):
    """Launch the fused dX kernel and return an (M, d_in) gradient.

    The output is allocated as zeros with dY's device and dtype.  When there
    are no active chunks it is returned as-is; otherwise every tile is filled
    by the kernel with the sum of the active chunks' contributions.
    """
    n_active = active_chunks.shape[0]

    grad_input = torch.zeros(M, d_in, device=dY.device, dtype=dY.dtype)
    if not n_active:
        return grad_input

    # The kernel reads chunk ids as a dense int32 vector.
    ids_i32 = active_chunks.to(torch.int32).contiguous()

    def launch_grid(meta):
        # One program per [BM, BK] output tile.
        return (
            triton.cdiv(M, meta['BM']),
            triton.cdiv(d_in, meta['BK']),
        )

    strides = (
        dY.stride(0), dY.stride(1),
        W.stride(0), W.stride(1),
        grad_input.stride(0), grad_input.stride(1),
    )
    _sparse_bwd_dX_kernel[launch_grid](
        dY, W, grad_input, ids_i32,
        M, d_in, dY.shape[1], n_active,
        *strides,
        CS=chunk_size,
    )
    return grad_input
248
+
249
+
250
+ # ── Kernel 3: Sparse dBias ────────────────────────────────────────
251
+ # Simple: bias_grad[c*CS:(c+1)*CS] = dY[:, c*CS:(c+1)*CS].sum(dim=0)
252
+
253
@triton.jit
def _sparse_bwd_dbias_kernel(
    dY_ptr, dB_ptr, chunk_ids_ptr,
    M, d_out, num_active,
    stride_dym, stride_dyn,
    CS: tl.constexpr,
    BM: tl.constexpr,
):
    """Column-sum dY for one output column of one active chunk.

    One program per (active chunk, column inside the chunk).  The column is
    reduced over M in BM-sized steps using an explicit fp32 vector
    accumulator, so the reduction precision does not depend on dY's dtype and
    the loop-carried value keeps a single type.  Columns at or beyond
    ``d_out`` (partial trailing chunk) are masked out of both the loads and
    the final store.
    """
    pid = tl.program_id(0)  # one per (active_chunk, col_within_chunk)
    chunk_linear = pid // CS
    col_in_chunk = pid % CS

    if chunk_linear >= num_active:
        return

    chunk_idx = tl.load(chunk_ids_ptr + chunk_linear)
    col_abs = chunk_idx * CS + col_in_chunk
    in_bounds = col_abs < d_out

    # Per-lane partial sums in fp32; reduced once after the loop.
    acc = tl.zeros((BM,), dtype=tl.float32)
    for m_start in range(0, M, BM):
        rm = m_start + tl.arange(0, BM)
        m_mask = (rm < M) & in_bounds
        vals = tl.load(dY_ptr + rm * stride_dym + col_abs * stride_dyn, mask=m_mask, other=0.0)
        acc += vals.to(tl.float32)

    total = tl.sum(acc, axis=0)
    tl.store(dB_ptr + col_abs, total.to(dB_ptr.dtype.element_ty), mask=in_bounds)
279
+
280
+
281
def sparse_bwd_dbias(dY, active_chunks, chunk_size, d_out):
    """Launch the dBias kernel and return a length-``d_out`` gradient.

    The result starts as zeros (dY's device and dtype); entries belonging to
    active chunks are overwritten with the column sums of dY.  Returns the
    zero vector without launching anything when no chunk is active.
    """
    n_rows = dY.shape[0]
    n_active = active_chunks.shape[0]

    grad_bias = torch.zeros(d_out, device=dY.device, dtype=dY.dtype)
    if not n_active:
        return grad_bias

    ids_i32 = active_chunks.to(torch.int32).contiguous()
    rows_per_step = 128
    # One program per (active chunk, column inside the chunk).
    launch_grid = (n_active * chunk_size,)
    _sparse_bwd_dbias_kernel[launch_grid](
        dY, grad_bias, ids_i32,
        n_rows, d_out, n_active,
        dY.stride(0), dY.stride(1),
        CS=chunk_size, BM=rows_per_step,
    )
    return grad_bias
297
+
298
+
299
+ # ═══════════════════════════════════════════════════════════════════
300
+ # AUTOGRAD FUNCTION: Triton-fused
301
+ # ═══════════════════════════════════════════════════════════════════
302
+
303
class TritonChunkedSparseLinear(torch.autograd.Function):
    """Linear layer with a dense forward and a chunk-sparse Triton backward.

    Forward is plain ``F.linear``.  Backward computes the weight (and bias)
    gradient only for the row chunks listed in ``active_chunks``; the input
    gradient is either restricted to those chunks (``sparse_dx=True``) or
    computed densely.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx):
        # Stash non-tensor settings on ctx, tensors via save_for_backward.
        ctx.chunk_size = chunk_size
        ctx.sparse_dx = sparse_dx
        ctx.has_bias = bias is not None
        ctx.save_for_backward(x, weight, active_chunks)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, active_chunks = ctx.saved_tensors
        chunk = ctx.chunk_size
        out_features, in_features = weight.shape

        # Collapse all leading dims so the kernels see 2-D operands.
        x2d = x.reshape(-1, in_features)
        gy2d = grad_y.reshape(-1, out_features)
        n_rows = x2d.shape[0]

        # Weight gradient (Triton)
        grad_w = sparse_bwd_dW(x2d, gy2d, active_chunks, chunk, out_features)

        # Bias gradient (Triton), only when the layer has a bias
        grad_b = None
        if ctx.has_bias:
            grad_b = sparse_bwd_dbias(gy2d, active_chunks, chunk, out_features)

        # Input gradient: chunk-sparse via Triton, or dense matmul
        if ctx.sparse_dx:
            grad_x2d = sparse_bwd_dX(gy2d, weight, active_chunks, chunk, n_rows, in_features)
        else:
            grad_x2d = gy2d @ weight

        # One gradient slot per forward argument; the last three are non-tensors.
        return grad_x2d.reshape(x.shape), grad_w, grad_b, None, None, None
335
+
336
+
337
+ # ═══════════════════════════════════════════════════════════════════
338
+ # AUTOGRAD FUNCTION: Python-loop baseline (for comparison)
339
+ # ═══════════════════════════════════════════════════════════════════
340
+
341
class PythonLoopSparseLinear(torch.autograd.Function):
    """Reference implementation: same contract, backward as a Python loop.

    Forward is plain ``F.linear``.  Backward fills the weight/bias gradient
    one active chunk at a time with ordinary torch ops; the input gradient is
    either accumulated per chunk (``sparse_dx=True``) or computed densely.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, active_chunks, chunk_size, sparse_dx):
        ctx.chunk_size = chunk_size
        ctx.sparse_dx = sparse_dx
        ctx.has_bias = bias is not None
        ctx.save_for_backward(x, weight, active_chunks)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, active_chunks = ctx.saved_tensors
        chunk = ctx.chunk_size

        # Collapse leading dims to 2-D.
        x2d = x.reshape(-1, x.shape[-1])
        gy2d = grad_y.reshape(-1, grad_y.shape[-1])

        grad_w = torch.zeros_like(weight)
        if ctx.has_bias:
            grad_b = torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)
        else:
            grad_b = None

        # Dense dX is computed up front; sparse dX is accumulated in the loop.
        grad_x2d = torch.zeros_like(x2d) if ctx.sparse_dx else gy2d @ weight

        for c in active_chunks.tolist():
            rows = slice(c * chunk, (c + 1) * chunk)
            gy_chunk = gy2d[:, rows]
            grad_w[rows, :] = gy_chunk.t() @ x2d
            if ctx.has_bias:
                grad_b[rows] = gy_chunk.sum(0)
            if ctx.sparse_dx:
                grad_x2d += gy_chunk @ weight[rows, :]

        # One gradient slot per forward argument; the last three are non-tensors.
        return grad_x2d.reshape(x.shape), grad_w, grad_b, None, None, None
373
+
374
+
375
+ # ═══════════════════════════════════════════════════════════════════
376
+ # CORRECTNESS TEST
377
+ # ═══════════════════════════════════════════════════════════════════
378
+
379
def test_correctness():
    """Compare the Triton kernels against a Python-loop reference on CUDA.

    For each (d_in, d_out, chunk_size) configuration, roughly 10% of the
    chunks are selected at random, reference dW / dB / sparse-dX are built
    with plain torch ops, and the max absolute error of each Triton result is
    printed together with a pass/fail marker (threshold 1e-2 per quantity).
    """
    print("Testing correctness...")
    torch.manual_seed(42)
    device = "cuda"

    for d_in, d_out, cs in [(512, 2048, 64), (1024, 4096, 64), (256, 1024, 32)]:
        M = 2048  # B*T
        n_chunks = d_out // cs
        n_active = max(1, int(0.1 * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values

        x = torch.randn(M, d_in, device=device, requires_grad=False)
        w = torch.randn(d_out, d_in, device=device, requires_grad=False)
        b = torch.randn(d_out, device=device, requires_grad=False)
        gy = torch.randn(M, d_out, device=device, requires_grad=False)

        # Reference: Python loop over the active chunks
        ref_dw = torch.zeros_like(w)
        ref_db = torch.zeros_like(b)
        ref_dx_sparse = torch.zeros_like(x)
        for c in active.tolist():
            s, e = c * cs, (c + 1) * cs
            ref_dw[s:e] = gy[:, s:e].t() @ x
            ref_db[s:e] = gy[:, s:e].sum(0)
            ref_dx_sparse += gy[:, s:e] @ w[s:e]

        # Triton
        tri_dw = sparse_bwd_dW(x, gy, active, cs, d_out)
        tri_db = sparse_bwd_dbias(gy, active, cs, d_out)
        tri_dx_sparse = sparse_bwd_dX(gy, w, active, cs, M, d_in)

        # Compare
        dw_err = (tri_dw - ref_dw).abs().max().item()
        db_err = (tri_db - ref_db).abs().max().item()
        dx_err = (tri_dx_sparse - ref_dx_sparse).abs().max().item()

        status = "✓" if dw_err < 1e-2 and db_err < 1e-2 and dx_err < 1e-2 else "✗"
        print(f" {status} d_in={d_in}, d_out={d_out}, cs={cs}: dW_err={dw_err:.6f}, dB_err={db_err:.6f}, dX_err={dx_err:.6f}")

    print()
424
+
425
+
426
+ # ═══════════════════════════════════════════════════════════════════
427
+ # BENCHMARK
428
+ # ═══════════════════════════════════════════════════════════════════
429
+
430
def benchmark():
    """Time dense, Python-loop and Triton backward passes on CUDA.

    Prints two tables: (1) dW + dense dX + dB for several ``d_in`` sizes,
    comparing dense, Python-loop and Triton variants; (2) dW + dX where both
    are computed by the sparse Triton kernels, compared against dense.
    Times are wall-clock averages over ``bench_iters`` calls, bracketed by
    ``torch.cuda.synchronize()``.
    """
    print("="*80)
    print("BENCHMARK: Triton Fused vs Python Loop vs Dense")
    print("="*80)
    device = "cuda"
    B, T = 8, 256
    M = B * T
    cs = 64
    af = 0.10
    warmup_iters = 10
    bench_iters = 50

    def _avg_seconds(fn):
        # Average wall-clock seconds per call; syncs keep queued GPU work
        # from leaking into or out of the timed window.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(bench_iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - t0) / bench_iters

    print(f"\nM={M} (B={B}, T={T}), chunk_size={cs}, active_frac={af}")
    print(f"{'d_model':>7} | {'d_out':>7} | {'active':>6} | {'Dense':>10} | {'PyLoop':>10} | {'Triton':>10} | {'Tri/Dense':>10} | {'Tri/PyLoop':>10}")
    print("-" * 95)

    for d_in in [256, 512, 768, 1024, 1536, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values

        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        b = torch.randn(d_out, device=device)
        gy = torch.randn(M, d_out, device=device)

        # Dense backward (dW + dX + dB)
        def dense_bwd():
            dw = gy.t() @ x
            dx = gy @ w
            db = gy.sum(0)
            return dw, dx, db

        # Python loop backward
        def pyloop_bwd():
            dw = torch.zeros_like(w)
            db = torch.zeros_like(b)
            dx = gy @ w  # dense dX
            for c in active.tolist():
                s, e = c * cs, (c + 1) * cs
                dw[s:e] = gy[:, s:e].t() @ x
                db[s:e] = gy[:, s:e].sum(0)
            return dw, dx, db

        # Triton fused backward
        def triton_bwd():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = gy @ w  # dense dX (same as pyloop)
            db = sparse_bwd_dbias(gy, active, cs, d_out)
            return dw, dx, db

        # Warmup (also triggers Triton autotuning outside the timed region)
        for _ in range(warmup_iters):
            dense_bwd()
            pyloop_bwd()
            triton_bwd()
        torch.cuda.synchronize()

        dense_time = _avg_seconds(dense_bwd)
        pyloop_time = _avg_seconds(pyloop_bwd)
        triton_time = _avg_seconds(triton_bwd)

        tri_vs_dense = dense_time / triton_time
        tri_vs_pyloop = pyloop_time / triton_time

        print(f"{d_in:>7} | {d_out:>7} | {n_active:>6} | {dense_time*1000:>9.2f}ms | {pyloop_time*1000:>9.2f}ms | {triton_time*1000:>9.2f}ms | {tri_vs_dense:>9.2f}x | {tri_vs_pyloop:>9.2f}x")

    # Also benchmark with sparse_dX (Triton dX kernel)
    print(f"\n{'='*80}")
    print("With Triton sparse_dX (both dW and dX are sparse):")
    print(f"{'d_model':>7} | {'Dense':>10} | {'Triton_all':>10} | {'Speedup':>10}")
    print("-" * 50)

    for d_in in [512, 1024, 2048]:
        d_out = 4 * d_in
        n_chunks = d_out // cs
        n_active = max(1, int(af * n_chunks))
        active = torch.randperm(n_chunks, device=device)[:n_active].sort().values

        x = torch.randn(M, d_in, device=device)
        w = torch.randn(d_out, d_in, device=device)
        gy = torch.randn(M, d_out, device=device)

        def dense_full():
            dw = gy.t() @ x
            dx = gy @ w
            return dw, dx

        def triton_full():
            dw = sparse_bwd_dW(x, gy, active, cs, d_out)
            dx = sparse_bwd_dX(gy, w, active, cs, M, d_in)
            return dw, dx

        for _ in range(warmup_iters):
            dense_full()
            triton_full()
        torch.cuda.synchronize()

        dt = _avg_seconds(dense_full)
        tt = _avg_seconds(triton_full)

        print(f"{d_in:>7} | {dt*1000:>9.2f}ms | {tt*1000:>9.2f}ms | {dt/tt:>9.2f}x")
543
+
544
+
545
+ # ═══════════════════════════════════════════════════════════════════
546
+ # MAIN
547
+ # ═══════════════════════════════════════════════════════════════════
548
+
549
if __name__ == "__main__":
    # Both entry points hard-code device="cuda"; fail fast with a clear
    # message instead of an opaque error from the first tensor allocation.
    if not torch.cuda.is_available():
        raise SystemExit("triton_sparse.py requires a CUDA-capable GPU")
    test_correctness()
    benchmark()