mamba-webgpu / shaders /softplus.wgsl
LJTSG's picture
Upload shaders/softplus.wgsl with huggingface_hub
4976636 verified
raw
history blame contribute delete
462 Bytes
// softplus.wgsl — element-wise softplus, in-place.
// softplus(x) = max(x, 0) + log(1 + exp(-|x|))
// Uniform parameters for the softplus kernel.
struct Params {
    // Count of leading f32 elements of the storage buffer to transform;
    // invocations whose index is >= n do nothing.
    n: u32,
}
// Runtime-sized f32 buffer, read and overwritten in place by the kernel.
@binding(0) @group(0) var<storage, read_write> data: array<f32>;
// Kernel parameters (element count), supplied as a uniform in bind group 1.
@binding(0) @group(1) var<uniform> params: Params;
// Threads per workgroup along x; shared by @workgroup_size and the grid stride.
const WG_SIZE: u32 = 64u;

// log(1 + u) for u >= 0 that keeps relative accuracy when u is small.
// WGSL has no log1p builtin, and evaluating log(1.0 + u) directly degrades as
// u shrinks: below about 6e-8 the sum rounds to exactly 1.0 and the result is
// 0. For u < 0.05 a degree-5 Taylor polynomial in Horner form is used instead;
// its truncation error (about u^5 / 6 relative) is comparable to f32 rounding
// error in that range. Larger u goes through the builtin log unchanged.
fn log1p_nonneg(u: f32) -> f32 {
    if (u < 0.05) {
        // u - u^2/2 + u^3/3 - u^4/4 + u^5/5
        return u * (1.0 - u * (0.5 - u * (1.0 / 3.0 - u * (0.25 - u * 0.2))));
    }
    return log(1.0 + u);
}

// softplus(x) = log(1 + exp(x)), evaluated as max(x, 0) + log1p(exp(-|x|)).
// exp(-|x|) lies in [0, 1], so exp never overflows. The small-u path of
// log1p_nonneg (|x| greater than about 3) keeps tiny positive outputs for very
// negative x instead of collapsing them to 0.
fn softplus_f32(x: f32) -> f32 {
    return max(x, 0.0) + log1p_nonneg(exp(-abs(x)));
}

// Entry point: applies softplus in place to data[0 .. params.n).
// Grid-stride loop along x. With the usual dispatch of ceil(n / WG_SIZE)
// workgroups in x, the stride is >= n and each invocation handles at most one
// element, exactly as before. The host may instead cap the x workgroup count
// (WebGPU's maxComputeWorkgroupsPerDimension defaults to 65535), and all n
// elements are still covered.
// NOTE(review): assumes a 1-D dispatch (y = z = 1); with more workgroups in y
// or z, several invocations would update the same element, applying softplus
// more than once — the original kernel had the same restriction.
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) nwg: vec3<u32>,
) {
    let stride = nwg.x * WG_SIZE;
    for (var i = gid.x; i < params.n; i += stride) {
        data[i] = softplus_f32(data[i]);
    }
}