Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb | using namespace metal; | |
/// Problem dimensions passed to the sparse-linear backward kernels in this file.
///
/// Both kernels bind this as a `constant` reference and read only `N`, `In`
/// and `Out`. The struct is four 32-bit words (16 bytes); `dummy` is never
/// read and exists purely as alignment padding.
/// NOTE(review): whatever host code fills this buffer presumably mirrors the
/// same four-word layout — confirm before adding or reordering fields.
struct SparseLinearParams {
    uint32_t N;      // leading dimension of x [N, In] and gy [N, Out]
    uint32_t In;     // trailing dimension of x, grad_x and of weight / grad_w [Out, In]
    uint32_t Out;    // trailing dimension of gy; length of active_mask and grad_b
    uint32_t dummy;  // alignment padding; unused by the kernels
};

// Lock the layout in: a field added or removed without updating the padding
// now fails at compile time instead of silently shifting the kernel's view of
// the parameter buffer.
static_assert(sizeof(SparseLinearParams) == 16,
              "SparseLinearParams must remain exactly 4 x uint32_t (16 bytes)");
/// Weight and bias gradients of a linear layer, restricted to active output rows.
///
/// One thread per element of grad_w: `tid.x` is the input column `c`,
/// `tid.y` is the output row. For a row whose `active_mask` entry is true:
///
///     grad_w[row, c] = sum over n of gy[n, row] * x[n, c]
///     grad_b[row]    = sum over n of gy[n, row]     (stored by the c == 0 thread)
///
/// Threads outside [Out, In] or on an inactive row return without writing, so
/// those entries of grad_w / grad_b keep whatever the caller stored there
/// (the buffer comments state the caller zeroes them). The mask is read on the
/// device, so no host-side readback is needed to skip inactive rows.
kernel void sparse_linear_grad_w_float(
    device const float* x             [[buffer(0)]],  // [N, In]
    device const float* gy            [[buffer(1)]],  // [N, Out]
    device const bool* active_mask    [[buffer(2)]],  // [Out] boolean mask
    device float* grad_w              [[buffer(3)]],  // [Out, In], zeroed by caller
    device float* grad_b              [[buffer(4)]],  // [Out], zeroed by caller
    constant SparseLinearParams& p    [[buffer(5)]],
    uint2 tid                         [[thread_position_in_grid]])
{
    const uint c = tid.x;
    const uint row = tid.y;

    // The grid may be padded past the matrix extents.
    if (row >= p.Out || c >= p.In) return;

    // Inactive rows do no work at all.
    if (!active_mask[row]) return;

    float acc = 0.0f;   // running grad_w[row, c]
    float bacc = 0.0f;  // running grad_b[row]; only the c == 0 thread stores it
    for (uint n = 0; n < p.N; ++n) {
        const float g = gy[n * p.Out + row];
        acc += g * x[n * p.In + c];
        // Accumulated in the same pass so the c == 0 thread does not need a
        // second N-length loop re-reading gy while its neighbours sit idle.
        // The add is unconditional to keep the loop body branch-free.
        bacc += g;
    }

    grad_w[row * p.In + c] = acc;
    if (c == 0) {
        grad_b[row] = bacc;
    }
}
/// Input gradient of a linear layer, counting only active output rows.
///
/// One thread per element of grad_x: `tid.x` selects the input column and
/// `tid.y` the sample index. Each in-range thread stores
///
///     grad_x[n, c] = sum over rows with active_mask[row] of gy[n, row] * weight[row, c]
///
/// Every in-range element is written (0 when no row is active); threads
/// outside [N, In] return without touching memory.
kernel void sparse_linear_grad_x_float(
    device const float* gy            [[buffer(0)]],  // [N, Out]
    device const float* weight        [[buffer(1)]],  // [Out, In]
    device const bool* active_mask    [[buffer(2)]],  // [Out] boolean mask
    device float* grad_x              [[buffer(3)]],  // [N, In], zeroed by caller
    constant SparseLinearParams& p    [[buffer(4)]],
    uint2 tid                         [[thread_position_in_grid]])
{
    const uint col = tid.x;
    const uint sample = tid.y;

    // Discard threads from a grid padded beyond the output extents.
    const bool in_range = (sample < p.N) && (col < p.In);
    if (!in_range) {
        return;
    }

    // Offset of this sample's row inside gy; invariant across the loop.
    const uint gy_base = sample * p.Out;

    float sum = 0.0f;
    for (uint out_row = 0; out_row < p.Out; ++out_row) {
        // Masked-out rows contribute nothing to the input gradient.
        if (!active_mask[out_row]) {
            continue;
        }
        sum += gy[gy_base + out_row] * weight[out_row * p.In + col];
    }

    grad_x[sample * p.In + col] = sum;
}