File size: 462 Bytes
// softplus.wgsl — element-wise softplus, in-place.
// softplus(x) = max(x, 0) + log(1 + exp(-|x|))
// Parameters supplied to the shader through a uniform binding.
struct Params {
    // Exclusive upper bound on the element indices the entry point processes;
    // invocations whose index is >= n return without touching the buffer.
    n: u32,
}
// Storage buffer of f32 values, bound at group 0 / binding 0. Declared
// read_write because the entry point reads each element and overwrites it
// with its softplus in place.
@group(0) @binding(0) var<storage, read_write> data: array<f32>;
// Uniform block bound at group 1 / binding 0 (a different bind group from
// `data`). Its `n` field bounds which indices are processed.
@group(1) @binding(0) var<uniform> params: Params;
// Compute entry point: replaces data[i] with softplus(data[i]) in place.
//
// Each invocation handles the single element at index gid.x (only the x
// component of the global invocation id is used). Invocations whose index is
// >= params.n return without reading or writing the buffer.
//
// softplus(x) = max(x, 0) + log(1 + exp(-|x|)). The argument of exp() is
// never positive, so exp() cannot overflow for large |x|.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= params.n) {
        return;
    }

    let x = data[i];
    let t = exp(-abs(x)); // in [0, 1] for any non-NaN x

    // Evaluating log(1 + t) directly rounds 1 + t to exactly 1 once t drops
    // below roughly 6e-8 (half an f32 ulp at 1.0), which made softplus(x)
    // collapse to 0 for x below about -16.6 even though the true value
    // (~exp(x)) is representable. For small t use the series
    // log(1 + t) = t - t^2/2 + O(t^3); below the 1e-4 cutoff its relative
    // truncation error (~t^2/3) is smaller than f32 precision.
    // A NaN input fails the comparison and takes the log() path, so NaN still
    // propagates to the output as before.
    let log1p_t = select(log(1.0 + t), t * (1.0 - 0.5 * t), t < 1.0e-4);

    data[i] = max(x, 0.0) + log1p_t;
}
|