#include <metal_stdlib>
using namespace metal;

// Shape parameters shared by the sparse linear backward kernels below.
// Four 32-bit fields (16 bytes); the buffer bound on the host side must use
// the same field order.
struct SparseLinearParams {
    uint32_t N;      // leading dimension of x / gy / grad_x
    uint32_t In;     // number of input features (columns of x and weight)
    uint32_t Out;    // number of output features (columns of gy, rows of weight)
    uint32_t dummy;  // alignment padding
};

// Weight and bias gradients of a linear layer, restricted to active output rows.
//
// Thread mapping: tid.x is the input column c, tid.y is the output row.
// Threads with row >= Out or c >= In exit immediately, so the dispatched grid
// may be larger than [In, Out].
//
// For an active row, each thread computes
//     grad_w[row, c] = sum_n gy[n, row] * x[n, c]
// and the thread with c == 0 additionally writes
//     grad_b[row]    = sum_n gy[n, row].
// Rows whose active_mask entry is false are never written; their gradient
// entries keep whatever the caller stored there beforehand (the caller is
// expected to zero both outputs).
kernel void sparse_linear_grad_w_float(
    device const float*         x            [[buffer(0)]],  // [N, In]
    device const float*         gy           [[buffer(1)]],  // [N, Out]
    device const bool*          active_mask  [[buffer(2)]],  // [Out] boolean mask
    device float*               grad_w       [[buffer(3)]],  // [Out, In], zeroed by caller
    device float*               grad_b       [[buffer(4)]],  // [Out], zeroed by caller
    constant SparseLinearParams& p           [[buffer(5)]],
    uint2                       tid          [[thread_position_in_grid]])
{
    const uint c = tid.x;
    const uint row = tid.y;

    // Bounds check: discard threads that fall outside the [Out, In] gradient.
    if (row >= p.Out || c >= p.In) return;

    // Inactive rows are skipped entirely on the GPU; no host-side filtering
    // of the dispatch is needed.
    if (!active_mask[row]) return;

    float acc = 0.0f;   // running sum for grad_w[row, c]
    float bacc = 0.0f;  // running sum for grad_b[row]

    // Running flat offsets are kept in size_t so that N * Out or N * In
    // exceeding 2^32 elements does not wrap around in 32-bit arithmetic.
    size_t gy_idx = row;  // walks gy[n, row]
    size_t x_idx = c;     // walks x[n, c]
    for (uint n = 0; n < p.N; ++n) {
        const float g = gy[gy_idx];
        acc += g * x[x_idx];
        // The bias sum is accumulated in the same pass (same order as a
        // separate loop would use) so gy is not traversed a second time by
        // the c == 0 thread while the rest of its SIMD group waits.
        bacc += g;
        gy_idx += p.Out;
        x_idx += p.In;
    }

    grad_w[static_cast<size_t>(row) * p.In + c] = acc;

    // Exactly one thread per row (c == 0) publishes the bias gradient.
    if (c == 0) {
        grad_b[row] = bacc;
    }
}

// Input gradient of a linear layer, using only active output rows.
//
// Thread mapping: tid.x is the input column c, tid.y is the row index n.
// Threads with n >= N or c >= In exit immediately.
//
// Each thread computes
//     grad_x[n, c] = sum over rows with active_mask[row] of gy[n, row] * weight[row, c]
// and always writes the result, so every in-range element of grad_x is
// overwritten by this kernel.
kernel void sparse_linear_grad_x_float(
    device const float*         gy           [[buffer(0)]],  // [N, Out]
    device const float*         weight       [[buffer(1)]],  // [Out, In]
    device const bool*          active_mask  [[buffer(2)]],  // [Out] boolean mask
    device float*               grad_x       [[buffer(3)]],  // [N, In], zeroed by caller
    constant SparseLinearParams& p           [[buffer(4)]],
    uint2                       tid          [[thread_position_in_grid]])
{
    const uint c = tid.x;
    const uint n = tid.y;

    if (n >= p.N || c >= p.In) return;

    // 64-bit offsets: n * Out and row * In can exceed 2^32 for large tensors.
    const size_t gy_row = static_cast<size_t>(n) * p.Out;  // start of gy[n, :]
    size_t w_idx = c;                                      // walks weight[row, c]

    float acc = 0.0f;
    // w_idx advances on every iteration, including skipped (inactive) rows,
    // so it always addresses weight[row, c].
    for (uint row = 0; row < p.Out; ++row, w_idx += p.In) {
        if (active_mask[row]) {
            acc += gy[gy_row + row] * weight[w_idx];
        }
    }

    grad_x[static_cast<size_t>(n) * p.In + c] = acc;
}