// Author: theapemachine
// Commit bc1b8eb: Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes.
// Includes utilities for synthetic data generation and model training. Implements chunked sparse updates
// and integrates with existing sparse linear layers.
#include <metal_stdlib>
using namespace metal;
// Dimension parameters for the sparse-linear backward kernels in this file,
// bound as `constant SparseLinearParams&`. Four 32-bit fields = 16 bytes.
// NOTE(review): the host-side struct must match this layout field-for-field — confirm.
struct SparseLinearParams {
uint32_t N;     // number of rows in x / gy / grad_x (leading dim of [N, In] and [N, Out])
uint32_t In;    // input feature count (columns of x, weight, grad_w, grad_x)
uint32_t Out;   // output feature count (columns of gy; rows of weight, grad_w; length of mask/grad_b)
uint32_t dummy; // alignment padding; not read by the kernels, pads the struct to 16 bytes
};
// Weight/bias gradient of a linear layer, restricted to active output rows.
//
// For each active output row `row` and input column `c` this computes:
//   grad_w[row, c] = sum_n gy[n, row] * x[n, c]
//   grad_b[row]    = sum_n gy[n, row]
// Rows whose active_mask entry is false are not written at all, so the caller
// must pre-zero grad_w and grad_b if inactive rows should read as 0.
//
// Dispatch: 2D grid, tid.x spans [0, In) and tid.y spans [0, Out); one thread
// per grad_w element. Out-of-range threads exit immediately. The grad_b entry
// of each active row is written by that row's c == 0 thread.
//
// NOTE(review): active_mask is read as Metal `bool` (1 byte per element) —
// confirm the host buffer uses the same element size.
// NOTE(review): index arithmetic is 32-bit `uint`; assumes N*In and N*Out fit.
kernel void sparse_linear_grad_w_float(
    device const float* x           [[buffer(0)]], // [N, In]
    device const float* gy          [[buffer(1)]], // [N, Out]
    device const bool*  active_mask [[buffer(2)]], // [Out] row mask
    device float*       grad_w      [[buffer(3)]], // [Out, In]; inactive rows untouched (caller zeroes)
    device float*       grad_b      [[buffer(4)]], // [Out]; inactive rows untouched (caller zeroes)
    constant SparseLinearParams& p  [[buffer(5)]],
    uint2 tid [[thread_position_in_grid]])
{
    const uint c   = tid.x;
    const uint row = tid.y;

    // Bounds check: the grid may be rounded up past [In, Out].
    if (row >= p.Out || c >= p.In) return;

    // Inactive rows exit before any further memory traffic; the mask is read on
    // the GPU, so the host does not need to read it back.
    if (!active_mask[row]) return;

    // Single pass over N. The bias sum reuses the gy value already loaded for
    // the weight gradient, so the c == 0 thread no longer runs a second
    // N-length loop. Every thread accumulates bacc (one extra add, no branch in
    // the loop) but only c == 0 stores it. Summation order is unchanged.
    float acc  = 0.0f;
    float bacc = 0.0f;
    uint gy_idx = row; // == n * p.Out + row
    uint x_idx  = c;   // == n * p.In  + c
    for (uint n = 0; n < p.N; ++n) {
        const float g = gy[gy_idx];
        acc  += g * x[x_idx];
        bacc += g;
        gy_idx += p.Out;
        x_idx  += p.In;
    }

    grad_w[row * p.In + c] = acc;
    if (c == 0) {
        grad_b[row] = bacc;
    }
}
// Input gradient of a linear layer, summing only over active output rows:
//   grad_x[n, c] = sum_{row : active_mask[row]} gy[n, row] * weight[row, c]
//
// Dispatch: 2D grid, tid.x spans [0, In) and tid.y spans [0, N); one thread
// per grad_x element, out-of-range threads exit immediately. Every in-range
// element of grad_x is overwritten, so its prior contents are irrelevant here.
//
// NOTE(review): active_mask is read as Metal `bool` (1 byte per element) —
// confirm the host buffer uses the same element size.
kernel void sparse_linear_grad_x_float(
    device const float* gy          [[buffer(0)]], // [N, Out]
    device const float* weight      [[buffer(1)]], // [Out, In]
    device const bool*  active_mask [[buffer(2)]], // [Out] row mask
    device float*       grad_x      [[buffer(3)]], // [N, In]; fully overwritten in range
    constant SparseLinearParams& p  [[buffer(4)]],
    uint2 tid [[thread_position_in_grid]])
{
    const uint col    = tid.x;
    const uint sample = tid.y;

    const bool in_range = (sample < p.N) && (col < p.In);
    if (!in_range) {
        return;
    }

    // Row base of gy for this sample; uint arithmetic wraps identically to the
    // flat `sample * p.Out + row` form.
    const uint gy_base = sample * p.Out;

    float sum   = 0.0f;
    uint  w_idx = col; // tracks row * p.In + col
    for (uint row = 0; row < p.Out; ++row, w_idx += p.In) {
        if (!active_mask[row]) {
            continue; // skip inactive rows without loading their weights
        }
        sum += gy[gy_base + row] * weight[w_idx];
    }

    grad_x[sample * p.In + col] = sum;
}