LJTSG committed on
Commit
b2c7dd7
·
verified ·
1 Parent(s): c3a1bdb

Upload shaders/matmul_gemv.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/matmul_gemv.wgsl +63 -0
shaders/matmul_gemv.wgsl ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // matmul_gemv.wgsl — M=1 specialized matrix-vector product.
2
+ // Ported from gfx1151_inference/shaders/matmul_gemv.comp.
3
+ //
4
+ // Computes c[n] = sum_k(a[k] * B[n, k]) for n in [0, N).
5
+ // A is a vector [K]
6
+ // B is row-major [N, K]
7
+ // C is a vector [N]
8
+ //
9
+ // One WORKGROUP per output element n. Dispatch (N, 1, 1).
10
+ // Each workgroup has 64 threads cooperating on one dot product.
11
+ // Threads stride over K: thread t handles k = t, t+64, t+128, ...
12
+ // Tree reduction in shared memory yields c[n].
13
+
14
// Invocations per workgroup. Used as the @workgroup_size of `main` and as the
// stride with which each invocation walks the K dimension.
const WG_SIZE = 64u;
15
+
16
// Problem dimensions, supplied by the host through a uniform buffer.
struct Params {
    // Number of output elements in c (rows of B).
    N : u32,
    // Length of each dot product: elements in a, columns of B.
    K : u32,
}
20
+
21
// Bind group 0: operand and result buffers.
@group(0) @binding(0) var<storage, read> a_buf: array<f32>;        // input vector a, [K]
@group(0) @binding(1) var<storage, read> b_buf: array<f32>;        // matrix B, row-major [N, K]
@group(0) @binding(2) var<storage, read_write> c_buf: array<f32>;  // output vector c, [N]

// Bind group 1: problem dimensions.
@group(1) @binding(0) var<uniform> params: Params;

// Per-workgroup scratch for the reduction: one partial sum per invocation.
// Sized from WG_SIZE (rather than a repeated literal) so the array can never
// drift out of sync with @workgroup_size if the constant is changed.
var<workgroup> partial: array<f32, WG_SIZE>;
28
+
29
// Entry point. Each workgroup produces one output element
//     c[n] = sum over k of a[k] * B[n, k]
// with its WG_SIZE invocations cooperating on the dot product.
//
// Dispatch: (N, 1, 1) works as before. The output index is linearised over
// the x and y workgroup dimensions, so when N exceeds the per-dimension
// dispatch limit (maxComputeWorkgroupsPerDimension, 65535 by default in
// WebGPU) the host may instead dispatch (X, Y, 1) with X * Y >= N; surplus
// workgroups exit at the bounds check. The z dimension is not used for
// indexing and is expected to be 1.
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {
    // For a (N, 1, 1) dispatch wg_id.y is 0, so n == wg_id.x.
    let n = wg_id.y * num_wg.x + wg_id.x;
    let tid = lid.x;

    // n and params.N are the same for every invocation of a workgroup, so the
    // whole workgroup returns together and the barriers below are still
    // reached in uniform control flow.
    if (n >= params.N) { return; }

    // Offset of row n inside the row-major matrix B.
    let b_base = n * params.K;

    // Strided partial dot product: invocation tid handles
    // k = tid, tid + WG_SIZE, tid + 2 * WG_SIZE, ...
    var acc: f32 = 0.0;
    for (var k: u32 = tid; k < params.K; k = k + WG_SIZE) {
        acc = acc + a_buf[k] * b_buf[b_base + k];
    }
    partial[tid] = acc;
    // Make every invocation's partial sum visible before reducing.
    workgroupBarrier();

    // Tree reduction: halve the number of active invocations each step until
    // partial[0] holds the full sum. The barrier sits outside the `if` so all
    // invocations execute it on every step.
    for (var off: u32 = WG_SIZE / 2u; off > 0u; off = off >> 1u) {
        if (tid < off) {
            partial[tid] = partial[tid] + partial[tid + off];
        }
        workgroupBarrier();
    }

    // A single invocation writes the result for this output element.
    if (tid == 0u) {
        c_buf[n] = partial[0u];
    }
}