LJTSG commited on
Commit
0116d58
·
verified ·
1 Parent(s): 4cd5770

Upload shaders/add_residual.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/add_residual.wgsl +14 -0
shaders/add_residual.wgsl ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // add_residual.wgsl — a[i] += b[i], in-place.
2
+ // Used for the residual connection: hidden_state += out_proj_output.
3
+
4
+ struct Params { n: u32, }
5
+ @group(0) @binding(0) var<storage, read_write> a_buf: array<f32>;
6
+ @group(0) @binding(1) var<storage, read> b_buf: array<f32>;
7
+ @group(1) @binding(0) var<uniform> params: Params;
8
+
9
+ @compute @workgroup_size(64)
10
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
11
+ let i = gid.x;
12
+ if (i >= params.n) { return; }
13
+ a_buf[i] = a_buf[i] + b_buf[i];
14
+ }