LJTSG commited on
Commit
7df1089
·
verified ·
1 Parent(s): b2c7dd7

Upload shaders/rmsnorm.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/rmsnorm.wgsl +66 -0
shaders/rmsnorm.wgsl ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // rmsnorm.wgsl — Root Mean Square LayerNorm (no mean subtraction).
2
+ // Ported from gfx1151_inference/shaders/rmsnorm.comp.
3
+ // y[t, d] = x[t, d] / sqrt(mean(x[t, :]^2) + eps) * weight[d]
4
+ // One workgroup per row. Tree reduction for variance.
5
+
6
+ const WG_SIZE: u32 = 64u;
7
+
8
+ struct Params {
9
+ rows: u32,
10
+ D: u32,
11
+ eps: f32,
12
+ }
13
+
14
+ @group(0) @binding(0) var<storage, read> x_buf: array<f32>;
15
+ @group(0) @binding(1) var<storage, read> w_buf: array<f32>;
16
+ @group(0) @binding(2) var<storage, read_write> y_buf: array<f32>;
17
+
18
+ @group(1) @binding(0) var<uniform> params: Params;
19
+
20
+ var<workgroup> partial_sumsq: array<f32, 64>;
21
+
22
+ @compute @workgroup_size(WG_SIZE)
23
+ fn main(
24
+ @builtin(workgroup_id) wg_id: vec3<u32>,
25
+ @builtin(local_invocation_id) lid: vec3<u32>
26
+ ) {
27
+ let row = wg_id.x;
28
+ let tid = lid.x;
29
+ if (row >= params.rows) { return; }
30
+
31
+ let base = row * params.D;
32
+
33
+ // 1. Each thread accumulates its strided slice of x[row]^2
34
+ var s: f32 = 0.0;
35
+ var d: u32 = tid;
36
+ loop {
37
+ if (d >= params.D) { break; }
38
+ let v = x_buf[base + d];
39
+ s = s + v * v;
40
+ d = d + WG_SIZE;
41
+ }
42
+ partial_sumsq[tid] = s;
43
+ workgroupBarrier();
44
+
45
+ // 2. Tree reduction
46
+ var off: u32 = WG_SIZE / 2u;
47
+ loop {
48
+ if (off == 0u) { break; }
49
+ if (tid < off) {
50
+ partial_sumsq[tid] = partial_sumsq[tid] + partial_sumsq[tid + off];
51
+ }
52
+ workgroupBarrier();
53
+ off = off >> 1u;
54
+ }
55
+
56
+ let mean_sq = partial_sumsq[0u] / f32(params.D);
57
+ let scale = 1.0 / sqrt(mean_sq + params.eps);
58
+
59
+ // 3. Apply: y = x * scale * weight
60
+ d = tid;
61
+ loop {
62
+ if (d >= params.D) { break; }
63
+ y_buf[base + d] = x_buf[base + d] * scale * w_buf[d];
64
+ d = d + WG_SIZE;
65
+ }
66
+ }