LJTSG commited on
Commit
4cee88d
·
verified ·
1 Parent(s): b3e0b44

Upload shaders/elementwise_mul.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/elementwise_mul.wgsl +14 -0
shaders/elementwise_mul.wgsl ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // elementwise_mul.wgsl — a[i] = a[i] * b[i], in-place.
2
+ // Used in Mamba mixer for gate: gated = hidden_y * silu(gate).
3
+
4
+ struct Params { n: u32, }
5
+ @group(0) @binding(0) var<storage, read_write> a_buf: array<f32>;
6
+ @group(0) @binding(1) var<storage, read> b_buf: array<f32>;
7
+ @group(1) @binding(0) var<uniform> params: Params;
8
+
9
+ @compute @workgroup_size(64)
10
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
11
+ let i = gid.x;
12
+ if (i >= params.n) { return; }
13
+ a_buf[i] = a_buf[i] * b_buf[i];
14
+ }