| #include "roll.hpp" |
| #include "common.hpp" |
|
|
| using namespace sycl; |
|
|
/**
 * Adds a non-negative offset to an index and wraps the result into [0, n).
 *
 * Preconditions: 0 <= i < n and 0 <= shift < n. Under these, one conditional
 * subtraction is enough to wrap the result, so no division is needed.
 *
 * The range test is done against the remaining headroom (n - shift) instead
 * of on the sum i + shift. The sum can exceed INT_MAX when n is large, which
 * would be signed overflow; the headroom form never leaves [0, n].
 *
 * @param i     index in [0, n)
 * @param shift offset in [0, n)
 * @param n     extent of the dimension (> 0)
 * @return      (i + shift) mod n
 */
static inline int wrap_add(int i, int shift, int n) {
    const int headroom = n - shift;  // in (0, n] under the preconditions
    return (i >= headroom) ? (i - headroom) : (i + shift);
}
|
|
| static void kernel_roll_fused_i0_i1( |
| queue &q, |
| const float *src_d, |
| float *dst_d, |
| int ne0, int ne1, int ne2, int ne3, |
| int sh0, int sh1, int sh2, int sh3) |
| { |
| if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return; |
|
|
|
|
| const int stride1 = ne0; |
| const int stride2 = ne0 * ne1; |
| const int stride3 = ne0 * ne1 * ne2; |
|
|
|
|
| const int shNe0 = (ne0 - sh0) % ne0; |
| const int shNe1 = (ne1 - sh1) % ne1; |
| const int shNe2 = (ne2 - sh2) % ne2; |
| const int shNe3 = (ne3 - sh3) % ne3; |
|
|
|
|
| const size_t g0 = (size_t) ne3; |
| const size_t g1 = (size_t) ne2; |
| const size_t g2 = (size_t) (ne1 * ne0); |
|
|
| const range<3> global{ g0, g1, g2 }; |
|
|
| q.submit([&](handler &h) { |
| h.parallel_for(global, [=](id<3> idx) { |
| const int i3 = (int) idx[0]; |
| const int i2 = (int) idx[1]; |
|
|
| const int fused = (int) idx[2]; |
| const int i1 = fused / ne0; |
| const int i0 = fused - i1 * ne0; |
|
|
|
|
| const int idx_dst = i0 |
| + i1 * stride1 |
| + i2 * stride2 |
| + i3 * stride3; |
|
|
|
|
| const int s0 = wrap_add(i0, shNe0, ne0); |
| const int s1 = wrap_add(i1, shNe1, ne1); |
| const int s2 = wrap_add(i2, shNe2, ne2); |
| const int s3 = wrap_add(i3, shNe3, ne3); |
|
|
| const int idx_src = s0 |
| + s1 * stride1 |
| + s2 * stride2 |
| + s3 * stride3; |
|
|
| dst_d[idx_dst] = src_d[idx_src]; |
| }); |
| }); |
| } |
|
|
| void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { |
| GGML_ASSERT(dst->type == GGML_TYPE_F32); |
|
|
| const ggml_tensor *src = dst->src[0]; |
| GGML_ASSERT(src && src->type == GGML_TYPE_F32); |
|
|
| const int ne0 = (int) dst->ne[0]; |
| const int ne1 = (int) dst->ne[1]; |
| const int ne2 = (int) dst->ne[2]; |
| const int ne3 = (int) dst->ne[3]; |
|
|
| const int32_t *params = (const int32_t *) dst->op_params; |
| int shift0 = params[0]; |
| int shift1 = params[1]; |
| int shift2 = params[2]; |
| int shift3 = params[3]; |
|
|
|
|
| if ((shift0 | shift1 | shift2 | shift3) == 0) { |
| const size_t nb = ggml_nbytes(src); |
| queue *q = ctx.stream(); |
| SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb))); |
| return; |
| } |
|
|
| auto norm = [](int sh, int n) -> int { |
| if (n <= 0) return 0; |
| sh %= n; |
| if (sh < 0) sh += n; |
| return sh; |
| }; |
| shift0 = norm(shift0, ne0); |
| shift1 = norm(shift1, ne1); |
| shift2 = norm(shift2, ne2); |
| shift3 = norm(shift3, ne3); |
|
|
| try { |
| queue *q = ctx.stream(); |
|
|
| const float *src_d = (const float *) src->data; |
| float *dst_d = (float *) dst->data; |
| GGML_ASSERT(src_d && dst_d); |
|
|
| kernel_roll_fused_i0_i1( |
| *q, src_d, dst_d, |
| ne0, ne1, ne2, ne3, |
| shift0, shift1, shift2, shift3 |
| ); |
| } catch (const std::exception &e) { |
| std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what()); |
| throw; |
| } |
| } |
|
|