File size: 462 Bytes
4976636
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// softplus.wgsl — element-wise softplus, in-place.
// softplus(x) = max(x, 0) + log(1 + exp(-|x|))

// Uniform block supplied alongside the element buffer.
struct Params {
    // Count of valid elements in `data`; the kernel skips indices >= n.
    n: u32,
}

// Element buffer: each element is read and overwritten in place.
@group(0) @binding(0)
var<storage, read_write> data : array<f32>;

// Element count, bound through its own bind group (group 1).
@group(1) @binding(0)
var<uniform> params : Params;

// Computes log(1 + u) for u in [0, 1] (inf/NaN pass through the log branch).
// WGSL has no log1p builtin, and evaluating log(1.0 + u) directly rounds
// 1.0 + u to exactly 1.0 once u drops below ~6e-8, returning 0 instead of ~u.
// For small u this uses the Taylor series u - u^2/2 + u^3/3 - u^4/4, whose
// truncation error (< u^5/5) stays below f32 half-ulp for u < 1/64. Larger u
// uses log(1.0 + u); there the ~6e-8 absolute rounding of 1.0 + u is small
// relative to the result.
fn softplus_log1p(u: f32) -> f32 {
    if (u < 0.015625) {
        return u * (1.0 - u * (0.5 - u * ((1.0 / 3.0) - u * 0.25)));
    }
    return log(1.0 + u);
}

// Applies softplus(x) = max(x, 0) + log(1 + exp(-|x|)) to every element of
// `data` with index < params.n, overwriting it in place.
//
// The element index is linearised over all three dispatch dimensions so the
// host can split large buffers across a 2-D/3-D grid once n exceeds what one
// dimension allows (maxComputeWorkgroupsPerDimension, 65535 by default, i.e.
// ~4.19M elements at 64 invocations per workgroup). With a 1-D dispatch,
// gid.y == gid.z == 0, so i == gid.x exactly as before.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
        @builtin(num_workgroups) num_groups: vec3<u32>) {
    // Invocations per grid row: workgroups along x times workgroup_size.x (64).
    let row_width = num_groups.x * 64u;
    let i = gid.x + row_width * (gid.y + num_groups.y * gid.z);
    if (i >= params.n) { return; }
    let x = data[i];
    // exp(-|x|) lies in [0, 1], so it cannot overflow; the max term carries
    // the magnitude for large positive x.
    data[i] = max(x, 0.0) + softplus_log1p(exp(-abs(x)));
}