gemma-webgpu / wllama-glu-fix.patch
LJTSG's picture
Upload wllama-glu-fix.patch with huggingface_hub
684eb39 verified
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index f4c5eca0d..0fa391cca 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -992,9 +992,10 @@ struct ggml_webgpu_glu_pipeline_key {
ggml_glu_op glu_op;
ggml_type type;
bool split;
+ bool inplace;
bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
- return glu_op == other.glu_op && type == other.type && split == other.split;
+ return glu_op == other.glu_op && type == other.type && split == other.split && inplace == other.inplace;
}
};
@@ -1004,6 +1005,7 @@ struct ggml_webgpu_glu_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.glu_op);
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.split);
+ ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
@@ -2906,7 +2908,10 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_glu_pipeline_key key = {};
key.glu_op = ggml_get_glu_op(context.dst);
key.type = context.dst->type;
- key.split = (context.src1 != nullptr);
+ // If src0 and src1 overlap, force NO_SPLIT mode (reads both halves from src0 with offset)
+ key.split = (context.src1 != nullptr) &&
+ !ggml_webgpu_tensor_overlap(context.src0, context.src1);
+ key.inplace = !key.split && ggml_webgpu_tensor_overlap(context.src0, context.dst);
auto it = glu_pipelines.find(key);
if (it != glu_pipelines.end()) {
@@ -2961,6 +2966,10 @@ class ggml_webgpu_shader_lib {
variant += "_split";
} else {
defines.push_back("NO_SPLIT");
+ if (key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ }
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index d577b5afa..7e27113dc 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -2531,7 +2531,8 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
- const int split = (src1 != nullptr);
+ // If src0 and src1 overlap, treat as non-split (NO_SPLIT shader reads both from src0)
+ const int split = (src1 != nullptr) && !ggml_webgpu_tensor_overlap(src0, src1);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -2558,15 +2559,18 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai
};
+ const bool inplace = !split && ggml_webgpu_tensor_overlap(src0, dst);
+
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
};
- uint32_t dst_binding = 1;
if (split) {
- dst_binding = 2;
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
+ entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
+ } else if (!inplace) {
+ entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
- entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst));
+ // When inplace: no dst binding — shader writes to src0 directly
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
index e6d7608ce..726643fe0 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl
@@ -97,11 +97,18 @@ struct Params {
var<storage, read_write> src0: array<DataType>;
#ifdef NO_SPLIT
+
+#ifdef INPLACE
+// Inplace: dst aliases src0 — use src0 for writes to avoid WebGPU aliasing violation
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
+#endif
fn a_value(base: u32) -> DataType {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
@@ -151,5 +158,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+#ifdef INPLACE
+ src0[i_dst] = op(a_value(i_a), b_value(i_b));
+#else
dst[i_dst] = op(a_value(i_a), b_value(i_b));
+#endif
}