diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f4c5eca0d..0fa391cca 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -992,9 +992,10 @@ struct ggml_webgpu_glu_pipeline_key { ggml_glu_op glu_op; ggml_type type; bool split; + bool inplace; bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { - return glu_op == other.glu_op && type == other.type && split == other.split; + return glu_op == other.glu_op && type == other.type && split == other.split && inplace == other.inplace; } }; @@ -1004,6 +1005,7 @@ struct ggml_webgpu_glu_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.glu_op); ggml_webgpu_hash_combine(seed, key.type); ggml_webgpu_hash_combine(seed, key.split); + ggml_webgpu_hash_combine(seed, key.inplace); return seed; } }; @@ -2906,7 +2908,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_glu_pipeline_key key = {}; key.glu_op = ggml_get_glu_op(context.dst); key.type = context.dst->type; - key.split = (context.src1 != nullptr); + // If src0 and src1 overlap, force NO_SPLIT mode (reads both halves from src0 with offset) + key.split = (context.src1 != nullptr) && + !ggml_webgpu_tensor_overlap(context.src0, context.src1); + key.inplace = !key.split && ggml_webgpu_tensor_overlap(context.src0, context.dst); auto it = glu_pipelines.find(key); if (it != glu_pipelines.end()) { @@ -2961,6 +2966,10 @@ class ggml_webgpu_shader_lib { variant += "_split"; } else { defines.push_back("NO_SPLIT"); + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d577b5afa..7e27113dc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2531,7 +2531,8 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); - const int split = (src1 != nullptr); + // If src0 and src1 overlap, treat as non-split (NO_SPLIT shader reads both from src0) + const int split = (src1 != nullptr) && !ggml_webgpu_tensor_overlap(src0, src1); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -2558,15 +2559,18 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; + const bool inplace = !split && ggml_webgpu_tensor_overlap(src0, dst); + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; - uint32_t dst_binding = 1; if (split) { - dst_binding = 2; entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } else if (!inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); + // When inplace: no dst binding — shader writes to src0 directly uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl index e6d7608ce..726643fe0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl @@ -97,11 +97,18 @@ struct Params { var src0: array; #ifdef NO_SPLIT + +#ifdef INPLACE +// Inplace: dst aliases src0 — use src0 for writes to avoid WebGPU aliasing violation +@group(0) @binding(1) +var params: Params; +#else @group(0) @binding(1) var dst: array; @group(0) @binding(2) var params: Params; +#endif fn a_value(base: u32) -> DataType { let offset: u32 = select(0, params.ne0, params.swapped != 0); @@ -151,5 +158,9 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; +#ifdef INPLACE + src0[i_dst] = op(a_value(i_a), b_value(i_b)); +#else dst[i_dst] = op(a_value(i_a), b_value(i_b)); +#endif }