| |
| |
| |
| |
| @@ -992,9 +992,10 @@ struct ggml_webgpu_glu_pipeline_key { |
| ggml_glu_op glu_op; |
| ggml_type type; |
| bool split; |
| + bool inplace; |
| |
| bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { |
| - return glu_op == other.glu_op && type == other.type && split == other.split; |
| + return glu_op == other.glu_op && type == other.type && split == other.split && inplace == other.inplace; |
| } |
| }; |
| |
| @@ -1004,6 +1005,7 @@ struct ggml_webgpu_glu_pipeline_key_hash { |
| ggml_webgpu_hash_combine(seed, key.glu_op); |
| ggml_webgpu_hash_combine(seed, key.type); |
| ggml_webgpu_hash_combine(seed, key.split); |
| + ggml_webgpu_hash_combine(seed, key.inplace); |
| return seed; |
| } |
| }; |
| @@ -2906,7 +2908,10 @@ class ggml_webgpu_shader_lib { |
| ggml_webgpu_glu_pipeline_key key = {}; |
| key.glu_op = ggml_get_glu_op(context.dst); |
| key.type = context.dst->type; |
| - key.split = (context.src1 != nullptr); |
| + // If src0 and src1 overlap, force NO_SPLIT mode (reads both halves from src0 with offset) |
| + key.split = (context.src1 != nullptr) && |
| + !ggml_webgpu_tensor_overlap(context.src0, context.src1); |
| + key.inplace = !key.split && ggml_webgpu_tensor_overlap(context.src0, context.dst); |
| |
| auto it = glu_pipelines.find(key); |
| if (it != glu_pipelines.end()) { |
| @@ -2961,6 +2966,10 @@ class ggml_webgpu_shader_lib { |
| variant += "_split"; |
| } else { |
| defines.push_back("NO_SPLIT"); |
| + if (key.inplace) { |
| + defines.push_back("INPLACE"); |
| + variant += "_inplace"; |
| + } |
| } |
| |
| defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); |
| |
| |
| |
| |
| @@ -2531,7 +2531,8 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, |
| |
| auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); |
| |
| - const int split = (src1 != nullptr); |
| + // If src0 and src1 overlap, treat as non-split: src1 is not bound and the NO_SPLIT shader reads both operands from src0 |
| + const int split = (src1 != nullptr) && !ggml_webgpu_tensor_overlap(src0, src1); |
| |
| std::vector<uint32_t> params = { |
| (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), |
| @@ -2558,15 +2559,18 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, |
| ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai |
| }; |
| |
| + const bool inplace = !split && ggml_webgpu_tensor_overlap(src0, dst); |
| + |
| std::vector<wgpu::BindGroupEntry> entries = { |
| ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), |
| }; |
| - uint32_t dst_binding = 1; |
| if (split) { |
| - dst_binding = 2; |
| entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); |
| + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); |
| + } else if (!inplace) { |
| + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); |
| } |
| - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); |
| + // When inplace: dst is not bound; the INPLACE shader variant writes its results through the src0 binding |
| |
| uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); |
| return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); |
| |
| |
| |
| |
| @@ -97,11 +97,18 @@ struct Params { |
| var<storage, read_write> src0: array<DataType>; |
| |
| #ifdef NO_SPLIT |
| + |
| +#ifdef INPLACE |
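| +// Inplace: dst aliases src0, so dst is not bound (params moves to binding 1) and results are written through src0 to avoid a WebGPU aliasing violation. NOTE(review): the write indexes src0 with i_dst (built from params.offset_dst) - confirm src0 and dst share the same base offset whenever they overlap. |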
| +@group(0) @binding(1) |
| +var<uniform> params: Params; |
| +#else |
| @group(0) @binding(1) |
| var<storage, read_write> dst: array<DataType>; |
| |
| @group(0) @binding(2) |
| var<uniform> params: Params; |
| +#endif |
| |
| fn a_value(base: u32) -> DataType { |
| let offset: u32 = select(0, params.ne0, params.swapped != 0); |
| @@ -151,5 +158,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { |
| let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; |
| let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; |
| |
| +#ifdef INPLACE |
| + src0[i_dst] = op(a_value(i_a), b_value(i_b)); |
| +#else |
| dst[i_dst] = op(a_value(i_a), b_value(i_b)); |
| +#endif |
| } |
|
|