| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| #ifndef GGML_SYCL_GEMM_HPP |
| #define GGML_SYCL_GEMM_HPP |
|
|
| #include "ggml-sycl.h" |
|
|
| #if GGML_SYCL_DNNL |
|
|
| #include "dnnl.hpp" |
| #include "dnnl_sycl.hpp" |
|
|
| class DnnlGemmWrapper { |
| public: |
| using dt = dnnl::memory::data_type; |
| using tag = dnnl::memory::format_tag; |
|
|
| template<typename T> |
| static constexpr dt to_dt() { |
| if constexpr (std::is_same_v<T, float>) return dt::f32; |
| else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16; |
| else static_assert(0); |
| } |
|
|
| static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, |
| const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2, |
| const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2, |
| void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { |
|
|
| auto stream = ctx.stream_dnnl(q); |
| auto eng = ctx.engine_dnnl(q); |
|
|
| dnnl::memory::dims a_dims = {batches_a, m, k }; |
| dnnl::memory::dims a_strides = {stra2, stra1, stra0}; |
| const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides); |
|
|
| dnnl::memory::dims b_dims = {batches_b, k, n }; |
| dnnl::memory::dims b_strides = {strb2, strb0, strb1}; |
| const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); |
|
|
| dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n}; |
| dnnl::memory::dims c_strides = {m*n, 1, m }; |
| const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides); |
| dnnl::primitive_attr primitive_attr; |
| primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); |
|
|
| #ifdef GGML_SYCL_F16 |
| primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16); |
| #endif |
|
|
| auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a)); |
| auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b)); |
| auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr); |
| auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); |
|
|
| auto scratchpad_md = matmul_pd.scratchpad_desc(); |
| auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q); |
|
|
| auto matmul_prim = dnnl::matmul(matmul_pd); |
|
|
| std::unordered_map<int, dnnl::memory> matmul_args; |
| matmul_args.insert({ DNNL_ARG_SRC, a_mem }); |
| matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); |
|
|
| matmul_args.insert({ DNNL_ARG_DST, c_mem }); |
| matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem }); |
|
|
| matmul_prim.execute(stream, matmul_args); |
| } |
|
|
| static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, |
| const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { |
|
|
| gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); |
| } |
| }; |
|
|
| #endif |
|
|
| #endif |
|
|