File size: 5,193 Bytes
684eb39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index f4c5eca0d..0fa391cca 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -992,9 +992,10 @@ struct ggml_webgpu_glu_pipeline_key {
     ggml_glu_op glu_op;
     ggml_type   type;
     bool        split;
+    bool        inplace;
 
     bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
-        return glu_op == other.glu_op && type == other.type && split == other.split;
+        return glu_op == other.glu_op && type == other.type && split == other.split && inplace == other.inplace;
     }
 };
 
@@ -1004,6 +1005,7 @@ struct ggml_webgpu_glu_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.glu_op);
         ggml_webgpu_hash_combine(seed, key.type);
         ggml_webgpu_hash_combine(seed, key.split);
+        ggml_webgpu_hash_combine(seed, key.inplace);
         return seed;
     }
 };
@@ -2906,7 +2908,10 @@ class ggml_webgpu_shader_lib {
         ggml_webgpu_glu_pipeline_key key = {};
         key.glu_op                       = ggml_get_glu_op(context.dst);
         key.type                         = context.dst->type;
-        key.split                        = (context.src1 != nullptr);
+        // If src0 and src1 overlap, force NO_SPLIT mode (reads both halves from src0 with offset)
+        key.split                        = (context.src1 != nullptr) &&
+                                           !ggml_webgpu_tensor_overlap(context.src0, context.src1);
+        key.inplace                      = !key.split && ggml_webgpu_tensor_overlap(context.src0, context.dst);
 
         auto it = glu_pipelines.find(key);
         if (it != glu_pipelines.end()) {
@@ -2961,6 +2966,10 @@ class ggml_webgpu_shader_lib {
             variant += "_split";
         } else {
             defines.push_back("NO_SPLIT");
+            if (key.inplace) {
+                defines.push_back("INPLACE");
+                variant += "_inplace";
+            }
         }
 
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index d577b5afa..7e27113dc 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -2531,7 +2531,8 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
 
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
-    const int split = (src1 != nullptr);
+    // If src0 and src1 overlap, treat as non-split (NO_SPLIT shader reads both from src0)
+    const int split = (src1 != nullptr) && !ggml_webgpu_tensor_overlap(src0, src1);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -2558,15 +2559,18 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
         ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)),  // limit, for swiglu_oai
     };
 
+    const bool inplace = !split && ggml_webgpu_tensor_overlap(src0, dst);
+
     std::vector<wgpu::BindGroupEntry> entries = {
         ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
     };
-    uint32_t dst_binding = 1;
     if (split) {
-        dst_binding = 2;
         entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
+    } else if (!inplace) {
+        entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
     }
-    entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst));
+    // When inplace, no dst entry is bound: the INPLACE shader variant writes its results through the src0 binding.
 
     uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
index e6d7608ce..726643fe0 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
@@ -97,11 +97,18 @@ struct Params {
 var<storage, read_write> src0: array<DataType>;
 
 #ifdef NO_SPLIT
+
+#ifdef INPLACE
+// Inplace: dst aliases src0 — use src0 for writes to avoid WebGPU aliasing violation
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
 @group(0) @binding(1)
 var<storage, read_write> dst: array<DataType>;
 
 @group(0) @binding(2)
 var<uniform> params: Params;
+#endif
 
 fn a_value(base: u32) -> DataType {
     let offset: u32 = select(0, params.ne0, params.swapped != 0);
@@ -151,5 +158,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
     let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
 
+#ifdef INPLACE
+    src0[i_dst] = op(a_value(i_a), b_value(i_b));
+#else
     dst[i_dst] = op(a_value(i_a), b_value(i_b));
+#endif
 }