Add all experiment code
Browse files- triton_sparse.py +15 -0
triton_sparse.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Triton-fused Chunked Sparse Backward Pass.
|
| 4 |
+
|
| 5 |
+
Replaces the Python for-loop over active chunks with fused Triton kernels:
|
| 6 |
+
1. sparse_bwd_dW: grad_W[c*CS:(c+1)*CS, :] = grad_Y[:, c*CS:(c+1)*CS].T @ X for active c
|
| 7 |
+
2. sparse_bwd_dX: grad_X += grad_Y[:, c*CS:(c+1)*CS] @ W[c*CS:(c+1)*CS, :] for active c
|
| 8 |
+
3. sparse_bwd_dbias: bias_grad[c*CS:(c+1)*CS] = dY[:, c*CS:(c+1)*CS].sum(dim=0)
|
| 9 |
+
|
| 10 |
+
Includes Python-loop baseline, correctness tests, and isolated matmul microbenchmark.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python triton_sparse.py # runs correctness + benchmark
|
| 14 |
+
"""
|
| 15 |
+
# See repo file for full content - uploading from sandbox
|