LJTSG commited on
Commit
684eb39
·
verified ·
1 Parent(s): 14405c7

Upload wllama-glu-fix.patch with huggingface_hub

Browse files
Files changed (1) hide show
  1. wllama-glu-fix.patch +116 -0
wllama-glu-fix.patch ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
2
+ index f4c5eca0d..0fa391cca 100644
3
+ --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
4
+ +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
5
+ @@ -992,9 +992,10 @@ struct ggml_webgpu_glu_pipeline_key {
6
+ ggml_glu_op glu_op;
7
+ ggml_type type;
8
+ bool split;
9
+ + bool inplace;
10
+
11
+ bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
12
+ - return glu_op == other.glu_op && type == other.type && split == other.split;
13
+ + return glu_op == other.glu_op && type == other.type && split == other.split && inplace == other.inplace;
14
+ }
15
+ };
16
+
17
+ @@ -1004,6 +1005,7 @@ struct ggml_webgpu_glu_pipeline_key_hash {
18
+ ggml_webgpu_hash_combine(seed, key.glu_op);
19
+ ggml_webgpu_hash_combine(seed, key.type);
20
+ ggml_webgpu_hash_combine(seed, key.split);
21
+ + ggml_webgpu_hash_combine(seed, key.inplace);
22
+ return seed;
23
+ }
24
+ };
25
+ @@ -2906,7 +2908,10 @@ class ggml_webgpu_shader_lib {
26
+ ggml_webgpu_glu_pipeline_key key = {};
27
+ key.glu_op = ggml_get_glu_op(context.dst);
28
+ key.type = context.dst->type;
29
+ - key.split = (context.src1 != nullptr);
30
+ + // If src0 and src1 overlap, force NO_SPLIT mode (reads both halves from src0 with offset)
31
+ + key.split = (context.src1 != nullptr) &&
32
+ + !ggml_webgpu_tensor_overlap(context.src0, context.src1);
33
+ + key.inplace = !key.split && ggml_webgpu_tensor_overlap(context.src0, context.dst);
34
+
35
+ auto it = glu_pipelines.find(key);
36
+ if (it != glu_pipelines.end()) {
37
+ @@ -2961,6 +2966,10 @@ class ggml_webgpu_shader_lib {
38
+ variant += "_split";
39
+ } else {
40
+ defines.push_back("NO_SPLIT");
41
+ + if (key.inplace) {
42
+ + defines.push_back("INPLACE");
43
+ + variant += "_inplace";
44
+ + }
45
+ }
46
+
47
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
48
+ diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
49
+ index d577b5afa..7e27113dc 100644
50
+ --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
51
+ +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
52
+ @@ -2531,7 +2531,8 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
53
+
54
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
55
+
56
+ - const int split = (src1 != nullptr);
57
+ + // If src0 and src1 overlap, treat as non-split (NO_SPLIT shader reads both from src0)
58
+ + const int split = (src1 != nullptr) && !ggml_webgpu_tensor_overlap(src0, src1);
59
+
60
+ std::vector<uint32_t> params = {
61
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
62
+ @@ -2558,15 +2559,18 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
63
+ ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai
64
+ };
65
+
66
+ + const bool inplace = !split && ggml_webgpu_tensor_overlap(src0, dst);
67
+ +
68
+ std::vector<wgpu::BindGroupEntry> entries = {
69
+ ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
70
+ };
71
+ - uint32_t dst_binding = 1;
72
+ if (split) {
73
+ - dst_binding = 2;
74
+ entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
75
+ + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
76
+ + } else if (!inplace) {
77
+ + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
78
+ }
79
+ - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst));
80
+ + // When inplace: no dst binding — shader writes to src0 directly
81
+
82
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
83
+ return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
84
+ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
85
+ index e6d7608ce..726643fe0 100644
86
+ --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
87
+ +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
88
+ @@ -97,11 +97,18 @@ struct Params {
89
+ var<storage, read_write> src0: array<DataType>;
90
+
91
+ #ifdef NO_SPLIT
92
+ +
93
+ +#ifdef INPLACE
94
+ +// Inplace: dst aliases src0 — use src0 for writes to avoid WebGPU aliasing violation
95
+ +@group(0) @binding(1)
96
+ +var<uniform> params: Params;
97
+ +#else
98
+ @group(0) @binding(1)
99
+ var<storage, read_write> dst: array<DataType>;
100
+
101
+ @group(0) @binding(2)
102
+ var<uniform> params: Params;
103
+ +#endif
104
+
105
+ fn a_value(base: u32) -> DataType {
106
+ let offset: u32 = select(0, params.ne0, params.swapped != 0);
107
+ @@ -151,5 +158,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
108
+ let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
109
+ let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
110
+
111
+ +#ifdef INPLACE
112
+ + src0[i_dst] = op(a_value(i_a), b_value(i_b));
113
+ +#else
114
+ dst[i_dst] = op(a_value(i_a), b_value(i_b));
115
+ +#endif
116
+ }