// Source file size: 2,052 bytes (revision bc1b8eb).
#include <metal_stdlib>
using namespace metal;

// Problem dimensions shared by the sparse-linear backward kernels in this file.
// Bound as a `constant` buffer (buffer(5) for grad_w, buffer(4) for grad_x).
// NOTE(review): presumably mirrored by a host-side struct; field order and the
// padding member must stay in sync with it — confirm against the caller.
struct SparseLinearParams {
    uint32_t N;     // number of samples (rows of x / gy / grad_x)
    uint32_t In;    // input features (columns of x, weight, grad_w)
    uint32_t Out;   // output features (columns of gy; rows of weight/grad_w)
    uint32_t dummy; // alignment padding; rounds the struct to 16 bytes, not read by the kernels
};

// Weight and bias gradients for the output rows selected by active_mask:
//   grad_w[row, c] = sum_n gy[n, row] * x[n, c]
//   grad_b[row]    = sum_n gy[n, row]
//
// Dispatch: 2-D grid with tid.x over input columns [0, In) and tid.y over
// output rows [0, Out). Threads outside that range are discarded by the
// bounds check, so the grid may be rounded up.
//
// Inactive rows are skipped entirely: their grad_w / grad_b entries are never
// written, which is why the caller must zero both buffers beforehand. The mask
// is read on the GPU, so no host readback of the mask is needed.
kernel void sparse_linear_grad_w_float(
    device const float* x          [[buffer(0)]], // [N, In]
    device const float* gy         [[buffer(1)]], // [N, Out]
    device const bool* active_mask [[buffer(2)]], // [Out] boolean mask
    device float* grad_w           [[buffer(3)]], // [Out, In], zeroed by caller
    device float* grad_b           [[buffer(4)]], // [Out], zeroed by caller
    constant SparseLinearParams& p [[buffer(5)]],
    uint2 tid                      [[thread_position_in_grid]])
{
    const uint c = tid.x;
    const uint row = tid.y;

    // Bounds check for grids rounded up past In x Out.
    if (row >= p.Out || c >= p.In) return;

    // Masked-out rows keep the caller's zeros.
    if (!active_mask[row]) return;

    // The bias sum uses the same gy values as the weight sum, so accumulate
    // both in a single pass. Previously the c == 0 lane re-read gy in a second
    // serial loop while the rest of its SIMD group waited. Every lane keeps
    // the (cheap) bias sum so the loop body stays branch-free; only c == 0
    // stores it. The summation order per row is unchanged (n = 0..N-1).
    float acc = 0.0f;
    float bacc = 0.0f;
    for (uint n = 0; n < p.N; ++n) {
        const float g = gy[n * p.Out + row];
        acc += g * x[n * p.In + c];
        bacc += g;
    }
    grad_w[row * p.In + c] = acc;

    // Exactly one thread per active row writes the bias gradient.
    if (c == 0) {
        grad_b[row] = bacc;
    }
}

// Input gradient restricted to the active output rows:
//   grad_x[n, c] = sum over rows r with active_mask[r] of gy[n, r] * weight[r, c]
//
// Dispatch: 2-D grid with tid.x over input columns [0, In) and tid.y over
// samples [0, N). Out-of-range threads return without touching memory.
// Every in-range (n, c) element of grad_x is assigned, so this kernel itself
// does not depend on grad_x being pre-zeroed.
kernel void sparse_linear_grad_x_float(
    device const float* gy         [[buffer(0)]], // [N, Out]
    device const float* weight     [[buffer(1)]], // [Out, In]
    device const bool* active_mask [[buffer(2)]], // [Out] boolean mask
    device float* grad_x           [[buffer(3)]], // [N, In], zeroed by caller
    constant SparseLinearParams& p [[buffer(4)]],
    uint2 tid                      [[thread_position_in_grid]])
{
    const uint col = tid.x;
    const uint sample = tid.y;

    const bool in_range = (sample < p.N) && (col < p.In);
    if (!in_range) {
        return;
    }

    float sum = 0.0f;
    for (uint r = 0; r < p.Out; ++r) {
        // Masked-out rows contribute nothing.
        if (!active_mask[r]) {
            continue;
        }
        sum += gy[sample * p.Out + r] * weight[r * p.In + col];
    }

    grad_x[sample * p.In + col] = sum;
}