| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| #include <algorithm> |
| #include <assert.h> |
| #include <atomic> |
| #include <cinttypes> |
| #include <cstddef> |
| #include <cstdint> |
| #include <cstdlib> |
| #include <float.h> |
| #include <limits> |
| #include <stdint.h> |
| #include <stdio.h> |
| #include <vector> |
| #include <cmath> |
| #include <iostream> |
| #include <fstream> |
| #include <stdio.h> |
| #include <stdlib.h> |
| #include <regex> |
|
|
| #include <sycl/sycl.hpp> |
| #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC |
| # include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp> |
| #endif |
| #include <sycl/half_type.hpp> |
|
|
| #include "ggml.h" |
| #include "ggml-sycl.h" |
| #include "ggml-impl.h" |
| #include "ggml-backend-impl.h" |
|
|
| #include "ggml-sycl/add-id.hpp" |
| #include "ggml-sycl/backend.hpp" |
| #include "ggml-sycl/common.hpp" |
| #include "ggml-sycl/element_wise.hpp" |
| #include "ggml-sycl/gemm.hpp" |
| #include "ggml-sycl/getrows.hpp" |
| #include "ggml-sycl/norm.hpp" |
| #include "ggml-sycl/presets.hpp" |
| #include "ggml-sycl/quantize.hpp" |
| #include "ggml-sycl/repeat_back.hpp" |
| #include "ggml-sycl/set_rows.hpp" |
| #include "ggml-sycl/set.hpp" |
| #include "ggml-sycl/ssm_conv.hpp" |
| #include "ggml-sycl/sycl_hw.hpp" |
|
|
|
|
// Set by ggml_check_sycl(): true after a successful device probe, false when
// the probe failed.
static bool g_sycl_loaded{false};

// Runtime switches. The values below are only the pre-initialisation state;
// ggml_check_sycl() assigns them on its first call.
int g_ggml_sycl_debug{0};                   // from GGML_SYCL_DEBUG
int g_ggml_sycl_disable_optimize{0};        // from GGML_SYCL_DISABLE_OPT
int g_ggml_sycl_disable_graph{0};           // from GGML_SYCL_DISABLE_GRAPH (env default there is 1)
int g_ggml_sycl_disable_dnn{0};             // from GGML_SYCL_DISABLE_DNN
int g_ggml_sycl_prioritize_dmmv{0};         // from GGML_SYCL_PRIORITIZE_DMMV
int g_ggml_sycl_use_async_mem_op{0};        // derived from the graph switch and device aspects when compiled in
int g_ggml_sycl_enable_flash_attention{1};  // from GGML_SYCL_ENABLE_FLASH_ATTN, forced to 0 without SYCL_FLASH_ATTN
|
|
|
|
| static ggml_sycl_device_info ggml_sycl_init() { |
| ggml_sycl_device_info info = {}; |
|
|
| info.device_count = dpct::dev_mgr::instance().device_count(); |
| if (info.device_count == 0) { |
| GGML_LOG_ERROR("%s: failed to initialize: %s\n", GGML_SYCL_NAME, __func__); |
| return info; |
| } |
|
|
| GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES); |
|
|
| int64_t total_vram = 0; |
| |
| |
| |
| |
| |
| |
| for (int i = 0; i < info.device_count; ++i) { |
| info.devices[i].vmm = 0; |
| dpct::device_info prop; |
| sycl::device device = dpct::dev_mgr::instance().get_device(i); |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( |
| prop, device))); |
|
|
| info.default_tensor_split[i] = total_vram; |
| total_vram += prop.get_global_mem_size(); |
|
|
| info.devices[i].cc = |
| 100 * prop.get_major_version() + 10 * prop.get_minor_version(); |
| info.devices[i].nsm = prop.get_max_compute_units() / 16; |
| info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); |
| info.devices[i].smpbo = prop.get_local_mem_size(); |
| info.devices[i].warp_size = WARP_SIZE; |
|
|
| info.max_work_group_sizes[i] = prop.get_max_work_group_size(); |
| info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); |
|
|
| } |
|
|
| for (int id = 0; id < info.device_count; ++id) { |
| info.default_tensor_split[id] /= total_vram; |
| } |
| return info; |
| } |
|
|
| const ggml_sycl_device_info & ggml_sycl_info() { |
| static ggml_sycl_device_info info = ggml_sycl_init(); |
| return info; |
| } |
|
|
| static void print_device_detail(int id, sycl::device &device, std::string device_type) { |
|
|
| dpct::device_info prop; |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| dpct::get_device_info(prop, device))); |
|
|
| std::string version; |
| version += std::to_string(prop.get_major_version()); |
| version += "."; |
| version += std::to_string(prop.get_minor_version()); |
|
|
| device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), ""); |
| std::string name = std::string(prop.get_name()); |
| name = std::regex_replace(name, std::regex("\\(R\\)"), ""); |
| name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); |
|
|
| auto global_mem_size = prop.get_global_mem_size()/1000000; |
| GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(), |
| name.c_str(), version.c_str(), prop.get_max_compute_units(), |
| prop.get_max_work_group_size(), prop.get_max_sub_group_size(), |
| global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str()); |
| } |
|
|
| static void print_device_opt_feature(int device_count) { |
| GGML_LOG_INFO("SYCL Optimization Feature:\n"); |
| GGML_LOG_INFO( |
| "|ID| Device Type|Reorder|\n"); |
| GGML_LOG_INFO( |
| "|--|-------------------|-------|\n"); |
| std::map<std::string, size_t> DeviceNums; |
| for (int id = 0; id < device_count; ++id) { |
| sycl::device device = dpct::dev_mgr::instance().get_device(id); |
| std::string backend_type = get_device_backend_and_type(device); |
| int type_id = DeviceNums[backend_type]++; |
| std::stringstream device_type; |
| device_type << "[" << backend_type << ":" << std::to_string(type_id) |
| << "]"; |
| std::string device_type_s = device_type.str(); |
| device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), ""); |
| GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(), |
| ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N"); |
| } |
|
|
| } |
| void ggml_backend_sycl_print_sycl_devices() { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n"); |
| int device_count = dpct::dev_mgr::instance().device_count(); |
| std::map<std::string, size_t> DeviceNums; |
| GGML_LOG_INFO("Found %d SYCL devices:\n", device_count); |
|
|
| GGML_LOG_INFO( |
| "| | | | " |
| " |Max | |Max |Global | |\n"); |
| GGML_LOG_INFO( |
| "| | | | " |
| " |compute|Max work|sub |mem | |\n"); |
| GGML_LOG_INFO( |
| "|ID| Device Type| " |
| "Name|Version|units |group |group|size | Driver version|\n"); |
| GGML_LOG_INFO( |
| "|--|-------------------|---------------------------------------|------" |
| "-|-------|--------|-----|-------|---------------------|\n"); |
|
|
| for (int id = 0; id < device_count; ++id) { |
| sycl::device device = dpct::dev_mgr::instance().get_device(id); |
| std::string backend_type = get_device_backend_and_type(device); |
| int type_id = DeviceNums[backend_type]++; |
| std::stringstream device_type; |
| device_type << "[" << backend_type << ":" << std::to_string(type_id) |
| << "]"; |
| print_device_detail(id, device, device_type.str()); |
| } |
|
|
| print_device_opt_feature(device_count); |
| } |
|
|
// Reads an unsigned integer from the environment variable `env_name`.
// Returns `default_val` when the variable is unset or does not start
// (after optional whitespace) with something sscanf accepts for "%u";
// otherwise returns the parsed value converted to int.
static inline int get_sycl_env(const char *env_name, int default_val) {
    const char *raw_value = getenv(env_name);
    if (raw_value == NULL) {
        return default_val;
    }

    unsigned parsed = 0;
    const bool parsed_ok = sscanf(raw_value, " %u", &parsed) == 1;
    return parsed_ok ? (int)parsed : default_val;
}
|
|
| static void ggml_check_sycl() try { |
| static bool initialized = false; |
|
|
| if (!initialized) { |
| g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); |
| g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); |
| g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); |
| g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); |
| g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); |
|
|
| #ifdef SYCL_FLASH_ATTN |
| g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); |
| #else |
| g_ggml_sycl_enable_flash_attention = 0; |
| #endif |
|
|
| GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); |
|
|
| GGML_LOG_INFO("Build with Macros:\n"); |
| #if defined(GGML_SYCL_FORCE_MMQ) |
| GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); |
| #endif |
| #if defined(GGML_SYCL_F16) |
| GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); |
| #endif |
| #if defined(GGML_SYCL_GRAPH) |
| GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n"); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n"); |
| #endif |
| #if defined(GGML_SYCL_DNNL) |
| GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n"); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); |
| #endif |
|
|
| GGML_LOG_INFO("Running with Environment Variables:\n"); |
| GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); |
| GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); |
| #ifdef GGML_SYCL_GRAPH |
| GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); |
| #endif |
| #if GGML_SYCL_DNNL |
| GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); |
| #endif |
| GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); |
|
|
| #ifdef SYCL_FLASH_ATTN |
| GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); |
| #else |
| GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n", |
| g_ggml_sycl_enable_flash_attention); |
| #endif |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC |
| g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; |
| if (g_ggml_sycl_use_async_mem_op) { |
| for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { |
| if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { |
| g_ggml_sycl_use_async_mem_op = 0; |
| break; |
| } |
| } |
| } |
| #endif |
| if (CHECK_TRY_ERROR(g_all_sycl_device_count = |
| dpct::dev_mgr::instance().device_count()) != 0) { |
| initialized = true; |
| g_sycl_loaded = false; |
| return; |
| } |
| GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES); |
|
|
| initialized = true; |
| g_sycl_loaded = true; |
| ggml_backend_sycl_print_sycl_devices(); |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| |
| |
| |
| |
| inline void check_allow_gpu_index(const int device_index) { |
| if (device_index >= ggml_sycl_info().device_count) { |
| char error_buf[256]; |
| snprintf( |
| error_buf, |
| sizeof(error_buf), |
| "%s error: device_index:%d is out of range: [0-%d]", |
| __func__, |
| device_index, |
| ggml_sycl_info().device_count - 1); |
| GGML_LOG_ERROR("%s\n", error_buf); |
| assert(false); |
| } |
| } |
|
|
| GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n"); |
| for(int i=0;i<max_len;i++) id_list[i] = -1; |
|
|
| for (int i=0;i< ggml_sycl_info().device_count;i++){ |
| if (i>=max_len) break; |
| id_list[i] = i; |
| } |
| return; |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| |
|
|
| struct ggml_backend_sycl_buffer_context { |
| int device; |
| void * dev_ptr = nullptr; |
| queue_ptr stream; |
| std::string name; |
| optimize_feature opt_feature; |
| std::vector<ggml_tensor_extra_gpu *> tensor_extras; |
|
|
| ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : |
| device(device), dev_ptr(dev_ptr), stream(stream) { |
| check_allow_gpu_index(device); |
| name = (GGML_SYCL_NAME + std::to_string(device)); |
| opt_feature = ggml_sycl_info().devices[device].opt_feature; |
| } |
|
|
| ~ggml_backend_sycl_buffer_context() { |
| if (dev_ptr != nullptr) { |
| ggml_sycl_set_device(device); |
| SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); |
| } |
|
|
| |
| for (ggml_tensor_extra_gpu * extra : tensor_extras) { |
| release_extra_gpu(extra); |
| } |
|
|
| } |
| }; |
|
|
| static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft); |
|
|
| static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) { |
| return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name; |
| } |
|
|
| static void |
| ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { |
| ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; |
| ggml_sycl_set_device(ctx->device); |
|
|
| delete ctx; |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { |
| ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; |
| return ctx->dev_ptr; |
| } |
|
|
| static enum ggml_status |
| ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, |
| ggml_tensor *tensor) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str()); |
| ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; |
|
|
| if (tensor->view_src != NULL) { |
| assert(tensor->view_src->buffer->buft == buffer->buft); |
| return GGML_STATUS_SUCCESS; |
| } |
| if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && |
| !g_ggml_sycl_disable_optimize) { |
| ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; |
| tensor->extra = extra; |
| ctx->tensor_extras.push_back(extra); |
| } |
|
|
| if (ggml_is_quantized(tensor->type)) { |
| |
| size_t original_size = ggml_nbytes(tensor); |
| size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); |
|
|
| if (padded_size > original_size && tensor->view_src == nullptr) { |
| SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset( |
| (char *)tensor->data + original_size, 0, |
| padded_size - original_size).wait())); |
| } |
| } |
| return GGML_STATUS_SUCCESS; |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, |
| ggml_tensor *tensor, |
| const void *data, size_t offset, |
| size_t size) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); |
| ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; |
| ggml_sycl_set_device(ctx->device); |
| auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); |
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); |
| #ifndef _WIN32 |
| |
| |
| char * host_buf = (char *) malloc(size); |
| memcpy(host_buf, data, size); |
| SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait())); |
| free(host_buf); |
| #else |
| SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait())); |
| #endif |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, |
| const ggml_tensor *tensor, |
| void *data, size_t offset, |
| size_t size) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); |
| ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; |
|
|
| ggml_sycl_set_device(ctx->device); |
| auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue(); |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR( |
| stream.memcpy(data, (const char *)tensor->data + offset, size) |
| .wait())); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, |
| const void *ptr_src, size_t size) { |
| char *host_buf = (char *)malloc(size); |
| q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); |
| q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); |
| free(host_buf); |
| } |
|
|
| static bool |
| ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, |
| const ggml_tensor *src, |
| ggml_tensor *dst) try { |
| bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer); |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str()); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str()); |
| GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); |
| if (is_cpy_supported) { |
| ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; |
| ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; |
|
|
| ggml_sycl_set_device(src_ctx->device); |
| |
| |
| |
| |
| |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw())); |
| ggml_sycl_set_device(dst_ctx->device); |
| |
| |
| |
| |
| |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); |
| |
| |
| |
| |
| |
|
|
| queue_ptr stream_dst = dst_ctx->stream; |
| queue_ptr stream_src = src_ctx->stream; |
| size_t size = ggml_nbytes(src); |
|
|
| |
| dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size); |
|
|
| |
| #if 0 |
| SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy( |
| (char *)dst->data, (const char *)src->data, size).wait())); |
|
|
| |
| |
| |
| |
| |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); |
| #endif |
| return true; |
| } |
| return false; |
| GGML_UNUSED(buffer); |
| } catch (const sycl::exception & exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, |
| uint8_t value) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size); |
| ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; |
|
|
| ggml_sycl_set_device(ctx->device); |
| queue_ptr stream = ctx->stream; |
| SYCL_CHECK( |
| CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR((*stream) |
| .memset(ctx->dev_ptr, value, buffer->size) |
| .wait())); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, |
| size_t offset, size_t size) { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value); |
| ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; |
| SYCL_CHECK(ggml_sycl_set_device(ctx->device)); |
| auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); |
| if (size == 0) { |
| return; |
| } |
| if (tensor->data == nullptr) { |
| GGML_ABORT("Error: Tensor data pointer is null.\n"); |
| } |
| void * target_ptr = static_cast<char *>(tensor->data) + offset; |
| SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size))); |
| SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait())); |
| } |
|
|
| static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) { |
| GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); |
| if (buffer == nullptr) { |
| return; |
| } |
|
|
| ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; |
|
|
| if (ctx != nullptr) { |
| for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) { |
| release_extra_gpu(extra); |
| } |
| ctx->tensor_extras.clear(); |
| } |
| } |
|
|
| static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { |
| ggml_backend_sycl_buffer_free_buffer, |
| ggml_backend_sycl_buffer_get_base, |
| ggml_backend_sycl_buffer_init_tensor, |
| ggml_backend_sycl_buffer_memset_tensor, |
| ggml_backend_sycl_buffer_set_tensor, |
| ggml_backend_sycl_buffer_get_tensor, |
| ggml_backend_sycl_buffer_cpy_tensor, |
| ggml_backend_sycl_buffer_clear, |
| ggml_backend_sycl_buffer_reset, |
| }; |
|
|
| |
| struct ggml_backend_sycl_buffer_type_context { |
| int device; |
| std::string name; |
|
|
| |
| queue_ptr stream = nullptr; |
| }; |
|
|
| static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) { |
| ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; |
|
|
| return ctx->name.c_str(); |
| } |
|
|
| static ggml_backend_buffer_t |
| ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, |
| size_t size) try { |
| ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; |
| ggml_sycl_set_device(buft_ctx->device); |
| const queue_ptr stream = buft_ctx->stream; |
| size = std::max(size, (size_t)1); |
|
|
| void * dev_ptr; |
| SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( |
| size, *stream))); |
| if (!dev_ptr) { |
| GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); |
| return nullptr; |
| } |
| ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream); |
| return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { |
| return 128; |
| GGML_UNUSED(buft); |
| } |
|
|
| static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { |
| return dpct::get_current_device().get_max_mem_alloc_size(); |
|
|
| GGML_UNUSED(buft); |
| } |
|
|
| static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { |
| size_t size = ggml_nbytes(tensor); |
| int64_t ne0 = tensor->ne[0]; |
|
|
| if (ggml_is_quantized(tensor->type)) { |
| if (ne0 % MATRIX_ROW_PADDING != 0) { |
| size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); |
| } |
| } |
|
|
| return size; |
|
|
| GGML_UNUSED(buft); |
| } |
|
|
| static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = { |
| ggml_backend_sycl_buffer_type_get_name, |
| ggml_backend_sycl_buffer_type_alloc_buffer, |
| ggml_backend_sycl_buffer_type_get_alignment, |
| ggml_backend_sycl_buffer_type_get_max_size, |
| ggml_backend_sycl_buffer_type_get_alloc_size, |
| NULL, |
| }; |
|
|
| ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) { |
| static std::mutex mutex; |
| std::lock_guard<std::mutex> lock(mutex); |
|
|
|
|
| auto dev_count = ggml_backend_sycl_get_device_count(); |
|
|
| if (device>=dev_count or device<0) { |
| GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", |
| device, dev_count-1); |
| GGML_ASSERT(device<dev_count); |
| } |
| static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES]; |
|
|
| static bool ggml_backend_sycl_buffer_type_initialized = false; |
|
|
| if (!ggml_backend_sycl_buffer_type_initialized) { |
| for (int i = 0; i < dev_count; i++) { |
| auto & device_i = dpct::dev_mgr::instance().get_device(i); |
| queue_ptr stream = &(device_i.default_queue()); |
| ggml_backend_sycl_buffer_types[i] = { |
| ggml_backend_sycl_buffer_type_interface, |
| ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i), |
| new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream}, |
| }; |
| } |
| ggml_backend_sycl_buffer_type_initialized = true; |
| } |
| return &ggml_backend_sycl_buffer_types[device]; |
| } |
|
|
| static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); |
|
|
| int device = ctx->device; |
| if (device>=ggml_sycl_info().device_count or device<0) { |
| GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", |
| device, ggml_sycl_info().device_count-1); |
| GGML_ASSERT(device<ggml_sycl_info().device_count); |
| } |
| static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES]; |
|
|
| static bool ggml_backend_sycl_buffer_type_initialized = false; |
|
|
| if (!ggml_backend_sycl_buffer_type_initialized) { |
| for (int i = 0; i < ggml_sycl_info().device_count; i++) { |
| ggml_backend_sycl_buffer_types[i] = { |
| ggml_backend_sycl_buffer_type_interface, |
| nullptr, |
| new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)}, |
| }; |
| } |
| ggml_backend_sycl_buffer_type_initialized = true; |
| } |
| return &ggml_backend_sycl_buffer_types[device]; |
| } |
|
|
| |
|
|
| static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) { |
| int64_t min_compute_capability = INT_MAX; |
| int64_t max_compute_capability = INT_MIN; |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) { |
| if (min_compute_capability > ggml_sycl_info().devices[i].cc) { |
| min_compute_capability = ggml_sycl_info().devices[i].cc; |
| } |
| if (max_compute_capability < ggml_sycl_info().devices[i].cc) { |
| max_compute_capability = ggml_sycl_info().devices[i].cc; |
| } |
| } |
| } |
|
|
| switch(type) { |
| case GGML_TYPE_Q4_0: |
| case GGML_TYPE_Q4_1: |
| return max_compute_capability >= VER_GEN9 ? 128 : 64; |
| case GGML_TYPE_Q5_0: |
| case GGML_TYPE_Q5_1: |
| case GGML_TYPE_Q8_0: |
| return 64; |
| case GGML_TYPE_F16: |
| case GGML_TYPE_F32: |
| return 1; |
| case GGML_TYPE_Q2_K: |
| case GGML_TYPE_Q3_K: |
| case GGML_TYPE_Q4_K: |
| case GGML_TYPE_Q5_K: |
| case GGML_TYPE_IQ2_XXS: |
| case GGML_TYPE_IQ2_XS: |
| case GGML_TYPE_IQ2_S: |
| case GGML_TYPE_IQ1_S: |
| case GGML_TYPE_IQ1_M: |
| case GGML_TYPE_IQ3_XXS: |
| case GGML_TYPE_IQ4_XS: |
| case GGML_TYPE_IQ4_NL: |
| return max_compute_capability >= VER_GEN9 ? 128 : 64; |
| case GGML_TYPE_IQ3_S: |
| return max_compute_capability >= VER_GEN9 ? 128 : 64; |
| case GGML_TYPE_Q6_K: |
| return 64; |
| default: |
| GGML_ABORT("fatal error"); |
| } |
| } |
|
|
| static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) { |
| const int64_t nrows = ggml_nrows(tensor); |
| const int64_t rounding = get_row_rounding(tensor->type, tensor_split); |
|
|
| *row_low = id == 0 ? 0 : nrows*tensor_split[id]; |
| *row_low -= *row_low % rounding; |
| if (id == ggml_sycl_info().device_count - 1) { |
| *row_high = nrows; |
| } else { |
| *row_high = nrows*tensor_split[id + 1]; |
| *row_high -= *row_high % rounding; |
| } |
| } |
|
|
| static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { |
| static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); |
|
|
| return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); |
| } |
|
|
| struct ggml_backend_sycl_split_buffer_type_context { |
| std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split; |
| }; |
|
|
| struct ggml_backend_sycl_split_buffer_context { |
| ~ggml_backend_sycl_split_buffer_context() try { |
| for (ggml_tensor_extra_gpu * extra : tensor_extras) { |
| release_extra_gpu(extra, streams); |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| std::vector<ggml_tensor_extra_gpu *> tensor_extras; |
| std::vector<queue_ptr> streams; |
| }; |
|
|
| static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { |
| ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; |
| delete ctx; |
| } |
|
|
| static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { |
| |
| return (void *)0x1000; |
|
|
| GGML_UNUSED(buffer); |
| } |
|
|
| static enum ggml_status |
| ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, |
| ggml_tensor *tensor) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str()); |
| GGML_ASSERT(tensor->view_src == nullptr); |
|
|
| ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; |
| ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; |
|
|
| const int64_t ne0 = tensor->ne[0]; |
|
|
| ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; |
|
|
| ctx->tensor_extras.push_back(extra); |
| ctx->streams.push_back(&(dpct::get_current_device().default_queue())); |
|
|
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| int64_t row_low, row_high; |
| get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); |
|
|
| int64_t nrows_split = row_high - row_low; |
| if (nrows_split == 0) { |
| continue; |
| } |
|
|
| size_t size = ggml_nbytes_split(tensor, nrows_split); |
| const size_t original_size = size; |
|
|
| |
| if (ne0 % MATRIX_ROW_PADDING != 0) { |
| size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); |
| } |
|
|
| |
| |
| ggml_sycl_set_device(i); |
| const queue_ptr stream = ctx->streams[i]; |
| char * buf; |
| |
| |
| |
| |
| |
| SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( |
| size, *stream))); |
| if (!buf) { |
| char err_buf[1024]; |
| snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); |
| throw std::runtime_error(err_buf); |
| } |
| |
| if (size > original_size) { |
| |
| |
| |
| |
| |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| (*stream) |
| .memset(buf + original_size, 0, size - original_size) |
| .wait())); |
| } |
|
|
| extra->data_device[i] = buf; |
|
|
| for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { |
| |
| |
| |
| |
| |
| SYCL_CHECK( |
| CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event())); |
| } |
| } |
| tensor->extra = extra; |
| return GGML_STATUS_SUCCESS; |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void |
| ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, |
| ggml_tensor *tensor, const void *data, |
| size_t offset, size_t size) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); |
| |
| GGML_ASSERT(offset == 0); |
| GGML_ASSERT(size == ggml_nbytes(tensor)); |
|
|
| ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; |
| ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; |
|
|
| const int64_t ne0 = tensor->ne[0]; |
| const size_t nb1 = tensor->nb[1]; |
| ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; |
|
|
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| int64_t row_low, row_high; |
| get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); |
|
|
| int64_t nrows_split = row_high - row_low; |
| if (nrows_split == 0) { |
| continue; |
| } |
|
|
| const size_t offset_split = row_low*nb1; |
| size_t size = ggml_nbytes_split(tensor, nrows_split); |
| const size_t original_size = size; |
|
|
| |
| if (ne0 % MATRIX_ROW_PADDING != 0) { |
| size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); |
| } |
|
|
| const char * buf_host = (const char *)data + offset_split; |
| |
| |
| |
| |
| |
| ggml_sycl_set_device(i); |
| const queue_ptr stream = ctx->streams[i]; |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| (*stream) |
| .memcpy(extra->data_device[i], buf_host, original_size) |
| .wait())); |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void |
| ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, |
| const ggml_tensor *tensor, void *data, |
| size_t offset, size_t size) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); |
| |
| GGML_ASSERT(offset == 0); |
| GGML_ASSERT(size == ggml_nbytes(tensor)); |
|
|
| ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; |
| ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; |
|
|
| const int64_t ne0 = tensor->ne[0]; |
| const size_t nb1 = tensor->nb[1]; |
| ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; |
|
|
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| int64_t row_low, row_high; |
| get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); |
|
|
| int64_t nrows_split = row_high - row_low; |
| if (nrows_split == 0) { |
| continue; |
| } |
|
|
| const size_t offset_split = row_low*nb1; |
| size_t size = ggml_nbytes_split(tensor, nrows_split); |
| const size_t original_size = size; |
|
|
| |
| if (ne0 % MATRIX_ROW_PADDING != 0) { |
| size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); |
| } |
|
|
| char * buf_host = (char *)data + offset_split; |
| |
| |
| |
| |
| |
| ggml_sycl_set_device(i); |
| const queue_ptr stream = ctx->streams[i]; |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| (*stream) |
| .memcpy(buf_host, extra->data_device[i], original_size) |
| .wait())); |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { |
| GGML_UNUSED(buffer); |
| GGML_UNUSED(value); |
| } |
|
|
| static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { |
| ggml_backend_sycl_split_buffer_free_buffer, |
| ggml_backend_sycl_split_buffer_get_base, |
| ggml_backend_sycl_split_buffer_init_tensor, |
| NULL, |
| ggml_backend_sycl_split_buffer_set_tensor, |
| ggml_backend_sycl_split_buffer_get_tensor, |
| NULL, |
| ggml_backend_sycl_split_buffer_clear, |
| NULL, |
| }; |
|
|
| |
|
|
| static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) { |
| return GGML_SYCL_NAME "_Split"; |
|
|
| GGML_UNUSED(buft); |
| } |
|
|
| static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { |
| return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name; |
| } |
|
|
| static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { |
| |
| |
| |
| |
| ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context(); |
|
|
| return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size); |
| } |
|
|
| static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { |
| return 128; |
| GGML_UNUSED(buft); |
| } |
|
|
| static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { |
| ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context; |
|
|
| size_t total_size = 0; |
|
|
| const int64_t ne0 = tensor->ne[0]; |
|
|
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| int64_t row_low, row_high; |
| get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i); |
|
|
| int64_t nrows_split = row_high - row_low; |
| if (nrows_split == 0) { |
| continue; |
| } |
|
|
| total_size += ggml_nbytes_split(tensor, nrows_split); |
|
|
| |
| if (ne0 % MATRIX_ROW_PADDING != 0) { |
| total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); |
| } |
| } |
|
|
| return total_size; |
| } |
|
|
| static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { |
| return false; |
|
|
| GGML_UNUSED(buft); |
| } |
|
|
| static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = { |
| ggml_backend_sycl_split_buffer_type_get_name, |
| ggml_backend_sycl_split_buffer_type_alloc_buffer, |
| ggml_backend_sycl_split_buffer_type_get_alignment, |
| NULL, |
| ggml_backend_sycl_split_buffer_type_get_alloc_size, |
| ggml_backend_sycl_split_buffer_type_is_host, |
| }; |
|
|
| ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) { |
| static std::mutex mutex; |
| std::lock_guard<std::mutex> lock(mutex); |
|
|
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n"); |
| ggml_check_sycl(); |
| |
| static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map; |
|
|
| std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {}; |
|
|
| bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; }); |
| if (all_zero) { |
| tensor_split_arr = ggml_sycl_info().default_tensor_split; |
| } else { |
| float split_sum = 0.0f; |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| tensor_split_arr[i] = split_sum; |
| split_sum += tensor_split[i]; |
| } |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| tensor_split_arr[i] /= split_sum; |
| } |
| } |
|
|
| auto it = buft_map.find(tensor_split_arr); |
| if (it != buft_map.end()) { |
| return &it->second; |
| } |
|
|
| struct ggml_backend_buffer_type buft { |
| ggml_backend_sycl_split_buffer_type_interface, |
| ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0), |
| new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr}, |
| }; |
|
|
| auto result = buft_map.emplace(tensor_split_arr, buft); |
| return &result.first->second; |
| } |
|
|
| |
|
|
| static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) { |
| return GGML_SYCL_NAME "_Host"; |
|
|
| GGML_UNUSED(buft); |
| } |
|
|
// Allocates `size` bytes of host memory aligned to `alignment`.
//
// On Windows this forwards to _aligned_malloc(). Elsewhere it uses
// aligned_alloc(), which is specified to take a size that is an integral
// multiple of the alignment; the request is therefore rounded up to the next
// multiple before the call. Returns nullptr on failure, when `alignment` is 0,
// or when rounding up would overflow size_t. Release the memory with
// free_aligned_mem_host().
inline void * aligned_malloc_host(size_t alignment, size_t size) {
#ifdef _WIN32
    return _aligned_malloc(size, alignment);
#else
    if (alignment == 0) {
        return nullptr;
    }
    const size_t remainder = size % alignment;
    if (remainder != 0) {
        const size_t pad = alignment - remainder;
        if (size > SIZE_MAX - pad) {
            return nullptr;  // rounded size is not representable
        }
        size += pad;
    }
    return aligned_alloc(alignment, size);
#endif
}
|
|
// Releases host memory: _aligned_free() on Windows, free() everywhere else.
inline void free_aligned_mem_host(void * memblock) {
#if !defined(_WIN32)
    free(memblock);
#else
    _aligned_free(memblock);
#endif
}
|
|
| static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { |
| free_aligned_mem_host((void *)buffer->context); |
| } |
|
|
| static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { |
| void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size); |
| if (ptr == nullptr) { |
| |
| return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); |
| } |
|
|
| |
| ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); |
| buffer->buft = buft; |
| buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer; |
|
|
| return buffer; |
| } |
|
|
| ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n"); |
| static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = { |
| { |
| ggml_backend_sycl_host_buffer_type_name, |
| ggml_backend_sycl_host_buffer_type_alloc_buffer, |
| ggml_backend_cpu_buffer_type()->iface.get_alignment, |
| NULL, |
| ggml_backend_cpu_buffer_type()->iface.get_alloc_size, |
| ggml_backend_cpu_buffer_type()->iface.is_host, |
| }, |
| ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0), |
| nullptr, |
| }; |
|
|
| return &ggml_backend_sycl_buffer_type_host; |
| } |
|
|
| |
| struct ggml_sycl_pool_leg : public ggml_sycl_pool { |
| static const int MAX_SYCL_BUFFERS = 256; |
|
|
| int device; |
| queue_ptr qptr; |
| struct ggml_sycl_buffer { |
| void * ptr = nullptr; |
| size_t size = 0; |
| }; |
|
|
| ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {}; |
| size_t pool_size = 0; |
|
|
| explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {} |
|
|
| ~ggml_sycl_pool_leg() { |
| for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { |
| ggml_sycl_buffer & b = buffer_pool[i]; |
| if (b.ptr != nullptr) { |
| SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); |
| pool_size -= b.size; |
| } |
| } |
| GGML_ASSERT(pool_size == 0); |
| } |
|
|
| void * alloc(size_t size, size_t * actual_size) override { |
| #ifdef DEBUG_sycl_MALLOC |
| int nnz = 0; |
| size_t max_size = 0; |
| #endif |
| size_t best_diff = 1ull << 36; |
| int ibest = -1; |
| for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { |
| ggml_sycl_buffer& b = buffer_pool[i]; |
| if (b.ptr != nullptr) { |
| #ifdef DEBUG_sycl_MALLOC |
| ++nnz; |
| if (b.size > max_size) max_size = b.size; |
| #endif |
| if (b.size >= size) { |
| size_t diff = b.size - size; |
| if (diff < best_diff) { |
| best_diff = diff; |
| ibest = i; |
| if (!best_diff) { |
| void * ptr = b.ptr; |
| *actual_size = b.size; |
| b.ptr = nullptr; |
| b.size = 0; |
| return ptr; |
| } |
| } |
| } |
| } |
| } |
| if (ibest >= 0) { |
| ggml_sycl_buffer& b = buffer_pool[ibest]; |
| void * ptr = b.ptr; |
| *actual_size = b.size; |
| b.ptr = nullptr; |
| b.size = 0; |
| return ptr; |
| } |
| void * ptr; |
| size_t look_ahead_size = (size_t) (1.05 * size); |
|
|
| SYCL_CHECK( |
| CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( |
| look_ahead_size, *qptr))); |
| if (!ptr) { |
| GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); |
| return nullptr; |
| } |
|
|
| *actual_size = look_ahead_size; |
| pool_size += look_ahead_size; |
|
|
| #ifdef DEBUG_SYCL_MALLOC |
| GGML_LOG_DEBUG("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, |
| (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); |
| #endif |
|
|
| |
| return ptr; |
| } |
|
|
| void free(void * ptr, size_t size) override { |
| for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { |
| ggml_sycl_buffer& b = buffer_pool[i]; |
| if (b.ptr == nullptr) { |
| b.ptr = ptr; |
| b.size = size; |
| return; |
| } |
| } |
| GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); |
| SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); |
| pool_size -= size; |
| } |
| }; |
|
|
| struct ggml_sycl_pool_host : public ggml_sycl_pool { |
| queue_ptr qptr; |
| int device; |
|
|
| inline static int counter{ 0 }; |
|
|
| struct ggml_sycl_buffer { |
| void * ptr = nullptr; |
| size_t size = 0; |
| }; |
|
|
| |
| static constexpr int MAX_POOL_SIZE{ 64 }; |
| std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE); |
| size_t pool_size = 0; |
|
|
| explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} |
|
|
| ~ggml_sycl_pool_host() { |
| for (int i = 0; i < MAX_POOL_SIZE; ++i) { |
| ggml_sycl_buffer & b = buffer_pool[i]; |
| if (b.ptr != nullptr) { |
| SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); |
| b.ptr = nullptr; |
| pool_size -= b.size; |
| b.size = 0; |
| } |
| } |
| counter = 0; |
| } |
|
|
| void * alloc(size_t size, size_t * actual_size) override { |
| if (counter == MAX_POOL_SIZE) { |
| ggml_sycl_buffer b = buffer_pool[0]; |
| void * ptr = b.ptr; |
| *actual_size = b.size; |
| counter = 1; |
| return ptr; |
| } |
| ggml_sycl_buffer & b = buffer_pool[counter]; |
|
|
| if (b.ptr == nullptr) { |
| void * ptr; |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); |
| if (!ptr) { |
| GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); |
| return nullptr; |
| } |
| pool_size += size; |
| *actual_size = size; |
| counter = counter + 1; |
| return ptr; |
| } else { |
| ++counter; |
| b.size = size; |
| return b.ptr; |
| } |
| } |
|
|
| void free(void * ptr, size_t size) override { |
| |
| |
| for (int i = 0; i < MAX_POOL_SIZE; ++i) { |
| ggml_sycl_buffer & b = buffer_pool[i]; |
| if (b.ptr == nullptr) { |
| b.ptr = ptr; |
| b.size = size; |
| return; |
| } |
| } |
| } |
| }; |
|
|
| std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { |
| |
| return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device)); |
| } |
|
|
| std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { |
| |
| |
| |
| |
| return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device)); |
| } |
|
|
| |
| |
|
|
| |
| typedef void (*ggml_sycl_op_mul_mat_t)( |
| ggml_backend_sycl_context & ctx, |
| const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, |
| const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, |
| float *dst_dd_i, const int64_t row_low, const int64_t row_high, |
| const int64_t src1_ncols, const int64_t src1_padded_row_size, |
| const queue_ptr &stream); |
|
|
|
|
|
|
| static void mul_mat_p021_f16_f32( |
| const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, |
| const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, |
| const sycl::nd_item<3> &item_ct1) { |
|
|
| const sycl::half *x = (const sycl::half *)vx; |
|
|
| const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) + |
| item_ct1.get_local_id(1); |
| const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) + |
| item_ct1.get_local_id(0); |
| const int channel_x = channel / (nchannels_y / nchannels_x); |
|
|
| const int nrows_y = ncols_x; |
| const int nrows_dst = nrows_x; |
| const int row_dst = row_x; |
|
|
| float tmp = 0.0f; |
|
|
| for (int col_x0 = 0; col_x0 < ncols_x; |
| col_x0 += item_ct1.get_local_range(2)) { |
| const int col_x = col_x0 + item_ct1.get_local_id(2); |
|
|
| if (col_x >= ncols_x) { |
| break; |
| } |
|
|
| |
| const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; |
| const float xi = |
| sycl::vec<sycl::half, 1>(x[ix]) |
| .convert<float, sycl::rounding_mode::automatic>()[0]; |
|
|
| const int row_y = col_x; |
|
|
|
|
| |
| const int iy = channel*nrows_y + row_y; |
|
|
| tmp += xi * y[iy]; |
| } |
|
|
| |
| const int idst = channel*nrows_dst + row_dst; |
|
|
| |
| #pragma unroll |
| for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { |
| tmp += |
| dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); |
| } |
|
|
| if (item_ct1.get_local_id(2) == 0) { |
| dst[idst] = tmp; |
| } |
| } |
|
|
| static void mul_mat_vec_nc_f16_f32( |
| const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, |
| const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor, |
| const sycl::nd_item<3> &item_ct1) { |
|
|
| const sycl::half *x = (const sycl::half *)vx; |
|
|
| const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) + |
| item_ct1.get_local_id(1); |
| const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) + |
| item_ct1.get_local_id(0); |
| const int channel_x = channel / channel_x_divisor; |
|
|
| const int nrows_dst = nrows_x; |
| const int row_dst = row_x; |
|
|
| const int idst = channel*nrows_dst + row_dst; |
|
|
| float tmp = 0.0f; |
|
|
| for (int col_x0 = 0; col_x0 < ncols_x; |
| col_x0 += item_ct1.get_local_range(2)) { |
| const int col_x = col_x0 + item_ct1.get_local_id(2); |
|
|
| if (col_x >= ncols_x) { |
| break; |
| } |
|
|
| const int row_y = col_x; |
|
|
| const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; |
| const int iy = channel * channel_stride_y + row_y; |
|
|
| const float xi = |
| sycl::vec<sycl::half, 1>(x[ix]) |
| .convert<float, sycl::rounding_mode::automatic>()[0]; |
|
|
| tmp += xi * y[iy]; |
| } |
|
|
| |
| #pragma unroll |
| for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { |
| tmp += |
| dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); |
| } |
|
|
| if (item_ct1.get_local_id(2) == 0) { |
| dst[idst] = tmp; |
| } |
| } |
|
|
| static void k_sum_rows_f32(const float * x, float * dst, const int ncols, |
| const sycl::nd_item<3> &item_ct1) { |
| const int row = item_ct1.get_group(1); |
| const int col = item_ct1.get_local_id(2); |
|
|
| float sum = 0.0f; |
| for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) { |
| sum += x[row * ncols + i]; |
| } |
|
|
| sum = warp_reduce_sum(sum, item_ct1); |
|
|
| if (col == 0) { |
| dst[row] = sum; |
| } |
| } |
|
|
|
|
// Exchanges the values of `a` and `b` through a temporary copy.
template<typename T>
static inline void ggml_sycl_swap(T & a, T & b) {
    T saved = a;
    a       = b;
    b       = saved;
}
|
|
| template <ggml_sort_order order> |
| __dpct_inline__ static void |
| k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad, |
| const int tasks_per_thread, const sycl::nd_item<3> &item_ct1, |
| uint8_t *dpct_local) { |
| |
| int col_index = item_ct1.get_local_id(2); |
| int row = item_ct1.get_group(1); |
|
|
| for (int i = 0; i < tasks_per_thread; i++) { |
| int col = col_index * tasks_per_thread + i; |
| if (col >= ncols_pad) { |
| return; |
| } |
| } |
|
|
| const float * x_row = x + row * ncols; |
| auto dst_row = (int *)dpct_local; |
|
|
| |
| for (int i=0;i<tasks_per_thread;i++){ |
| int col = col_index*tasks_per_thread+i; |
| dst_row[col] = col; |
| } |
|
|
| item_ct1.barrier(sycl::access::fence_space::local_space); |
|
|
| for (int k = 2; k <= ncols_pad; k *= 2) { |
| for (int j = k / 2; j > 0; j /= 2) { |
| for (int i = 0; i < tasks_per_thread; i++) { |
| int col = col_index * tasks_per_thread + i; |
| int ixj = col ^ j; |
| if (ixj > col) { |
| if ((col & k) == 0) { |
| if (dst_row[col] >= ncols || |
| (dst_row[ixj] < ncols && |
| (order == GGML_SORT_ORDER_ASC |
| ? x_row[dst_row[col]] > x_row[dst_row[ixj]] |
| : x_row[dst_row[col]] < |
| x_row[dst_row[ixj]]))) { |
| ggml_sycl_swap(dst_row[col], dst_row[ixj]); |
| } |
| } else { |
| if (dst_row[ixj] >= ncols || |
| (dst_row[col] < ncols && |
| (order == GGML_SORT_ORDER_ASC |
| ? x_row[dst_row[col]] < x_row[dst_row[ixj]] |
| : x_row[dst_row[col]] > |
| x_row[dst_row[ixj]]))) { |
| ggml_sycl_swap(dst_row[col], dst_row[ixj]); |
| } |
| } |
| } |
| item_ct1.barrier(sycl::access::fence_space::local_space); |
| } |
| } |
| } |
|
|
| |
| for (int i = 0; i < tasks_per_thread; i++) { |
| int col = col_index * tasks_per_thread + i; |
| if (col < ncols) { |
| dst[row * ncols + col] = dst_row[col]; |
| } |
| } |
| } |
|
|
| static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past, |
| const sycl::nd_item<3> &item_ct1) { |
| const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + |
| item_ct1.get_local_id(1); |
| const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + |
| item_ct1.get_local_id(2); |
|
|
| if (col >= ncols) { |
| return; |
| } |
|
|
| const int i = row*ncols + col; |
| |
| |
| dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; |
| } |
|
|
| static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k, |
| const sycl::nd_item<3> &item_ct1) { |
| const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + |
| item_ct1.get_local_id(2); |
|
|
| if (i >= k) { |
| return; |
| } |
|
|
| dst[i] = scale * x[i] + bias; |
| } |
|
|
|
|
| template <typename Ti, typename To> |
| static void pool2d_nchw_kernel( |
| const int ih, const int iw, const int oh, const int ow, |
| const int kh, const int kw, const int sh, const int sw, |
| const int ph, const int pw, const int parallel_elements, |
| const Ti* src, To* dst, const enum ggml_op_pool op, |
| const sycl::nd_item<3> &item_ct1) { |
| int idx = item_ct1.get_local_id(2) + |
| item_ct1.get_group(2) * item_ct1.get_local_range(2); |
| if (idx >= parallel_elements) { |
| return; |
| } |
|
|
| const int I_HW = ih * iw; |
| const int O_HW = oh * ow; |
| const int nc = idx / O_HW; |
| const int cur_oh = idx % O_HW / ow; |
| const int cur_ow = idx % O_HW % ow; |
| const Ti* i_ptr = src + nc * I_HW; |
| To* o_ptr = dst + nc * O_HW; |
| const int start_h = cur_oh * sh - ph; |
| const int bh = sycl::max(0, start_h); |
| const int eh = sycl::min(ih, start_h + kh); |
| const int start_w = cur_ow * sw - pw; |
| const int bw = sycl::max(0, start_w); |
| const int ew = sycl::min(iw, start_w + kw); |
|
|
| To res = 0; |
|
|
| switch (op) { |
| case GGML_OP_POOL_AVG: res = 0; break; |
| case GGML_OP_POOL_MAX: res = -FLT_MAX; break; |
| default: |
| res = (To) sycl::nan(uint32_t(0)); |
| break; |
| } |
|
|
| for (int i = bh; i < eh; i += 1) { |
| for (int j = bw; j < ew; j += 1) { |
| #if DPCT_COMPATIBILITY_TEMP >= 350 |
| |
| |
| |
| |
| |
| |
| Ti cur = *(i_ptr + i * iw + j); |
| #else |
| Ti cur = i_ptr[i * iw + j]; |
| #endif |
| switch (op) { |
| case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break; |
| case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break; |
| default: |
| res = (To) sycl::nan(uint32_t(0)); |
| break; |
| } |
| } |
| } |
| o_ptr[cur_oh * ow + cur_ow] = res; |
| } |
|
|
|
|
| static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, |
| float *dst, const int ncols_x, |
| const int nrows_x, |
| const int nchannels_x, |
| const int nchannels_y, |
| queue_ptr stream) { |
|
|
| const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); |
| const sycl::range<3> block_dims(1, 1, WARP_SIZE); |
| { |
| dpct::has_capability_or_fail(stream->get_device(), |
| {sycl::aspect::fp16}); |
|
|
| stream->parallel_for( |
| sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |
| mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x, |
| nchannels_y, item_ct1); |
| }); |
| } |
| } |
|
|
| static void ggml_mul_mat_vec_nc_f16_f32_sycl( |
| const void *vx, const float *y, float *dst, const int ncols_x, |
| const int nrows_x, const int row_stride_x, const int nchannels_x, |
| const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) { |
|
|
| const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); |
| const sycl::range<3> block_dims(1, 1, WARP_SIZE); |
| { |
| dpct::has_capability_or_fail(stream->get_device(), |
| {sycl::aspect::fp16}); |
|
|
| stream->parallel_for( |
| sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |
| mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x, |
| row_stride_x, channel_stride_x, channel_stride_y, |
| nchannels_y / nchannels_x, item_ct1); |
| }); |
| } |
| } |
|
|
|
|
|
|
| static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias, |
| const int k, queue_ptr stream) { |
| const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE; |
| stream->parallel_for( |
| sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * |
| sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE), |
| sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)), |
| [=](sycl::nd_item<3> item_ct1) { |
| scale_f32(x, dst, scale, bias, k, item_ct1); |
| }); |
| } |
|
|
|
|
| static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, |
| const int nrows, queue_ptr stream) { |
| const sycl::range<3> block_dims(1, 1, WARP_SIZE); |
| const sycl::range<3> block_nums(1, nrows, 1); |
| stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) |
| [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |
| k_sum_rows_f32(x, dst, ncols, item_ct1); |
| }); |
| } |
|
|
// Returns the smallest power of two that is not less than x; any x <= 1
// yields 1.
static int next_power_of_2(int x) {
    int result = 1;
    for (; result < x; result *= 2) {
    }
    return result;
}
|
|
| static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, |
| const int nrows, ggml_sort_order order, |
| queue_ptr stream, int device) { |
| |
| const int ncols_pad = next_power_of_2(ncols); |
|
|
| int nth = 1; |
| int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; |
| while (nth < ncols_pad && nth < max_block_size) |
| nth *= 2; |
| if (nth > max_block_size) |
| nth = max_block_size; |
|
|
| const int tasks_per_thread = ncols_pad / nth; |
|
|
| const sycl::range<3> block_dims(1, 1, nth); |
| const sycl::range<3> block_nums(1, nrows, 1); |
| const size_t shared_mem = ncols_pad * sizeof(int); |
| GGML_ASSERT(shared_mem<=ggml_sycl_info().devices[device].smpbo); |
|
|
| if (order == GGML_SORT_ORDER_ASC) { |
| stream->submit([&](sycl::handler &cgh) { |
| sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1( |
| sycl::range<1>(shared_mem), cgh); |
|
|
| cgh.parallel_for( |
| sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) { |
| k_argsort_f32_i32<GGML_SORT_ORDER_ASC>( |
| x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1, |
| dpct_local_acc_ct1 |
| .get_multi_ptr<sycl::access::decorated::no>() |
| .get()); |
| }); |
| }); |
| } else if (order == GGML_SORT_ORDER_DESC) { |
| stream->submit([&](sycl::handler &cgh) { |
| sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1( |
| sycl::range<1>(shared_mem), cgh); |
|
|
| cgh.parallel_for( |
| sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) { |
| k_argsort_f32_i32<GGML_SORT_ORDER_DESC>( |
| x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1, |
| dpct_local_acc_ct1 |
| .get_multi_ptr<sycl::access::decorated::no>() |
| .get()); |
| }); |
| }); |
| } else { |
| GGML_ABORT("fatal error"); |
| } |
| } |
|
|
| static void top_k_f32_sycl( |
| const float * src, |
| int32_t * dst_indices, |
| const int64_t ncols, |
| const int64_t nrows, |
| const int k, |
| dpct::queue_ptr main_stream |
| ) { |
| const int block_size = 128; |
|
|
| const sycl::range<1> block_dims(block_size); |
| const sycl::range<1> grid_dims(nrows); |
|
|
| main_stream->submit([&](sycl::handler &cgh) { |
| sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh); |
| sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh); |
|
|
| cgh.parallel_for( |
| sycl::nd_range<1>(grid_dims * block_dims, block_dims), |
| [=](sycl::nd_item<1> item_ct1) { |
| const int row = item_ct1.get_group(0); |
| const int tid = item_ct1.get_local_id(0); |
|
|
| if (row >= nrows) return; |
|
|
| const float * src_row = src + row * ncols; |
| int32_t * dst_idx_row = dst_indices + row * k; |
|
|
| float local_vals[32]; |
| int local_idx[32]; |
|
|
| for (int i = 0; i < k; i++) { |
| local_vals[i] = -FLT_MAX; |
| local_idx[i] = -1; |
| } |
|
|
| for (int col = tid; col < ncols; col += block_size) { |
| float val = src_row[col]; |
|
|
| if (val > local_vals[k-1]) { |
| int pos = k - 1; |
| while (pos > 0 && val > local_vals[pos - 1]) { |
| pos--; |
| } |
|
|
| for (int i = k - 1; i > pos; i--) { |
| local_vals[i] = local_vals[i - 1]; |
| local_idx[i] = local_idx[i - 1]; |
| } |
| local_vals[pos] = val; |
| local_idx[pos] = col; |
| } |
| } |
|
|
| for (int i = 0; i < k; i++) { |
| shared_vals[tid * k + i] = local_vals[i]; |
| shared_idx[tid * k + i] = local_idx[i]; |
| } |
| item_ct1.barrier(sycl::access::fence_space::local_space); |
|
|
| if (tid == 0) { |
| float final_vals[32]; |
| int final_idx[32]; |
|
|
| for (int i = 0; i < k; i++) { |
| final_vals[i] = -FLT_MAX; |
| final_idx[i] = -1; |
| } |
|
|
| for (int t = 0; t < block_size; t++) { |
| for (int i = 0; i < k; i++) { |
| float val = shared_vals[t * k + i]; |
| int idx = shared_idx[t * k + i]; |
|
|
| if (val > final_vals[k-1]) { |
| int pos = k - 1; |
| while (pos > 0 && val > final_vals[pos - 1]) { |
| pos--; |
| } |
|
|
| for (int j = k - 1; j > pos; j--) { |
| final_vals[j] = final_vals[j - 1]; |
| final_idx[j] = final_idx[j - 1]; |
| } |
| final_vals[pos] = val; |
| final_idx[pos] = idx; |
| } |
| } |
| } |
|
|
| for (int i = 0; i < k; i++) { |
| dst_idx_row[i] = final_idx[i]; |
| } |
|
|
| if (k > 1) { |
| int32_t temp = dst_idx_row[0]; |
| dst_idx_row[0] = dst_idx_row[1]; |
| dst_idx_row[1] = temp; |
| } |
| } |
| }); |
| }); |
| } |
|
|
| static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, |
| const int nrows, queue_ptr stream) { |
| const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); |
| const sycl::range<3> block_nums(1, nrows, 1); |
| const size_t shared_mem = 256 * sizeof(float); |
|
|
| stream->submit([&](sycl::handler &cgh) { |
| sycl::local_accessor<float, 1> shared_data( |
| sycl::range<1>(shared_mem/sizeof(float)), cgh); |
| sycl::local_accessor<int, 1> shared_indices( |
| sycl::range<1>(shared_mem/sizeof(float)), cgh); |
|
|
| cgh.parallel_for( |
| sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) { |
| const int tid = item_ct1.get_local_id(2); |
| const int row = item_ct1.get_global_id(1); |
|
|
| float max_val = -INFINITY; |
| int max_idx = -1; |
|
|
| for (int col = tid; col < ncols; col += 256) { |
| float val = x[row * ncols + col]; |
| if (val > max_val) { |
| max_val = val; |
| max_idx = col; |
| } |
| } |
|
|
| shared_data[tid] = max_val; |
| shared_indices[tid] = max_idx; |
| item_ct1.barrier(sycl::access::fence_space::local_space); |
|
|
| for (int stride = 256/2; stride > 0; stride >>= 1) { |
| if (tid < stride) { |
| float val1 = shared_data[tid]; |
| float val2 = shared_data[tid + stride]; |
| if (val2 > val1) { |
| shared_data[tid] = val2; |
| shared_indices[tid] = shared_indices[tid + stride]; |
| } |
| } |
| item_ct1.barrier(sycl::access::fence_space::local_space); |
| } |
|
|
|
|
| if (tid == 0) { |
| dst[row] = shared_indices[0]; |
| } |
| }); |
| }); |
| } |
| static void diag_mask_inf_f32_sycl(const float *x, float *dst, |
| const int ncols_x, const int nrows_x, |
| const int rows_per_channel, const int n_past, |
| queue_ptr stream) { |
| const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1); |
| const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE; |
| const sycl::range<3> block_nums(1, block_num_x, nrows_x); |
| stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) { |
| diag_mask_inf_f32(x, dst, ncols_x, |
| rows_per_channel, n_past, |
| item_ct1); |
| }); |
| } |
|
|
| static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, |
| const struct ggml_tensor *src, |
| int64_t i3, int64_t i2, |
| int64_t i1_low, int64_t i1_high, |
| queue_ptr stream) try { |
|
|
| dpct::memcpy_direction kind; |
| char * src_ptr; |
| if (ggml_backend_buffer_is_host(src->buffer)) { |
| kind = dpct::host_to_device; |
| |
| src_ptr = (char *) src->data; |
| |
| } else if (ggml_backend_buffer_is_sycl(src->buffer)) { |
| |
| |
| kind = dpct::device_to_device; |
| src_ptr = (char *) src->data; |
| } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) { |
| |
| |
| |
| |
| GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]); |
| kind = dpct::device_to_device; |
| ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; |
| int id; |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| id = get_current_device_id())); |
| |
| src_ptr = (char *) extra->data_device[id]; |
| } else { |
| |
| GGML_ABORT("fatal error"); |
| } |
| char * dst_ptr = (char *) dst; |
|
|
| GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne); |
| GGML_TENSOR_LOCALS(int64_t, nb, src, nb); |
| const enum ggml_type type = src->type; |
| const int64_t ts = ggml_type_size(type); |
| const int64_t bs = ggml_blck_size(type); |
| int64_t i1_diff = i1_high - i1_low; |
|
|
| const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; |
| if (nb0 == ts && nb1 == ts*ne0/bs) { |
| |
| |
| return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1, |
| kind, *stream)); |
|
|
| } else if (nb0 == ts) { |
| return CHECK_TRY_ERROR( |
| dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1, |
| ts * ne0 / bs, i1_diff, kind, *stream)); |
| } else { |
| for (int64_t i1 = 0; i1 < i1_diff; i1++) { |
| const void * rx = (const void *) ((const char *) x + i1*nb1); |
| void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); |
| |
| dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy( |
| rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream)); |
| |
| |
| |
| |
| |
| |
| |
| if (r != 0) return r; |
| } |
| return 0; |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
// Computes one GEMM slice dst_dd_i = src0_dd_i^T * src1 for rows [row_low, row_high)
// of src0 and src1_ncols columns of src1, using either an fp16 or an fp32 path.
//
// Parameters:
//   ctx                   backend context; supplies the memory pool and the main device id
//   src0, src1, dst       tensor descriptors (types, shapes, op_params)
//   src0_dd_i             device data for this src0 slice (must be non-null)
//   src1_ddf_i            device data for this src1 slice (must be non-null)
//   src1_ddq_i            unused here (marked GGML_UNUSED)
//   dst_dd_i              output buffer, written as float (must be non-null)
//   row_low, row_high     src0 row range handled by this call
//   src1_ncols            number of src1 columns in this call
//   src1_padded_row_size  only read when src1 is F16, where it is used as an
//                         element offset into src1->data
//   stream                queue on which all work is enqueued
//
// A sycl::exception that escapes is reported on stderr and terminates the process.
| inline void ggml_sycl_op_mul_mat_sycl( |
| ggml_backend_sycl_context & ctx, |
| const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, |
| const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, |
| float *dst_dd_i, const int64_t row_low, const int64_t row_high, |
| const int64_t src1_ncols, const int64_t src1_padded_row_size, |
| const queue_ptr &stream) try { |
|
|
| GGML_ASSERT(src0_dd_i != nullptr); |
| GGML_ASSERT(src1_ddf_i != nullptr); |
| GGML_ASSERT(dst_dd_i != nullptr); |
|
|
// The shared (inner) dimension of both operands must agree.
| const int64_t ne00 = src0->ne[0]; |
| const int64_t ne10 = src1->ne[0]; |
| GGML_ASSERT(ne00 == ne10); |
|
|
| const int64_t row_diff = row_high - row_low; |
|
|
| int id; |
| SYCL_CHECK( |
| CHECK_TRY_ERROR(id = get_current_device_id())); |
|
|
| const int64_t ne0 = dst->ne[0]; |
| 
| 
// Leading dimension of the output: full dst row length (ne0) when running on
// the context's main device, otherwise just this call's row_diff.
| int ldc = id == ctx.device ? ne0 : row_diff; |
|
|
| #ifdef GGML_SYCL_F16 |
| bool use_fp16 = true; |
| #else |
| bool use_fp16 = false; |
| #endif |
// fp16 path: requires an F16 or quantized, contiguous src0 covering all of its
// rows in this call, and default precision requested in dst->op_params[0].
| if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && |
| row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { |
// Convert src0 to fp16 into a pool buffer unless it already is F16.
| ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool()); |
| if (src0->type != GGML_TYPE_F16) { |
| scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, 2, |
| " : converting src0 to fp16"); |
| const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst); |
| GGML_ASSERT(to_fp16_sycl != nullptr); |
| size_t ne = row_diff*ne00; |
| src0_as_f16.alloc(ne); |
| to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream); |
| } |
| const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16 |
| ? (const sycl::half *)src0_dd_i |
| : src0_as_f16.get(); |
|
|
// Same for src1. When src1 is already F16 the pointer is taken from src1->data
// offset by src1_padded_row_size elements rather than from src1_ddf_i.
// NOTE(review): this relies on the caller passing an element offset in
// src1_padded_row_size for F16 src1 - confirm against the call site.
| ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool()); |
| if (src1->type != GGML_TYPE_F16) { |
| scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, 2, |
| " : converting src1 to fp16"); |
| const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); |
| GGML_ASSERT(to_fp16_sycl != nullptr); |
| size_t ne = src1_ncols*ne10; |
| src1_as_f16.alloc(ne); |
| to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream); |
| } |
| const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16 |
| ? (const sycl::half *)src1->data + src1_padded_row_size |
| : src1_as_f16.get(); |
|
|
// With oneDNN compiled in and not disabled at runtime, the half x half -> float
// GEMM writes directly into dst_dd_i.
| #if GGML_SYCL_DNNL |
| if (!g_ggml_sycl_disable_dnn) { |
| DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr, |
| DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), |
| dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream); |
| } |
| else |
| #endif |
| { |
// Otherwise: half GEMM into a temporary half buffer, then convert to float.
// NOTE(review): dst_f16 holds row_diff*src1_ncols elements while the GEMM uses
// leading dimension ldc - presumably ldc == row_diff on this path; confirm.
| ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols); |
|
|
| const sycl::half alpha_f16 = 1.0f; |
| const sycl::half beta_f16 = 0.0f; |
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( |
| *stream, oneapi::mkl::transpose::trans, |
| oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, |
| &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, |
| src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, |
| dst_f16.get(), dpct::library_data_t::real_half, ldc, |
| dpct::library_data_t::real_half))); |
| scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, 2, |
| " : converting dst to fp32"); |
| const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); |
| to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); |
| } |
| } else { |
// fp32 path: convert any non-F32 operand to float in pool buffers, then run a
// float GEMM straight into dst_dd_i.
| ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool()); |
| ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool()); |
| if (src0->type != GGML_TYPE_F32) { |
| scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, 2, |
| " : converting src0 to fp32"); |
| const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst); |
| GGML_ASSERT(to_fp32_sycl != nullptr); |
| src0_ddq_as_f32.alloc(row_diff*ne00); |
| to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); |
| } |
| if (src1->type != GGML_TYPE_F32) { |
| scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, 2, |
| " : converting src1 to fp32"); |
| const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst); |
| GGML_ASSERT(to_fp32_sycl != nullptr); |
| src1_ddq_as_f32.alloc(src1_ncols*ne10); |
| to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream); |
| } |
| const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); |
| const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); |
|
|
| #if GGML_SYCL_DNNL |
| if (!g_ggml_sycl_disable_dnn) { |
| DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, |
| DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(), |
| dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream); |
| } |
| else |
| #endif |
| { |
// oneMKL column-major GEMM: src0 transposed, src1 as-is, alpha = 1, beta = 0.
| const float alpha = 1.0f; |
| const float beta = 0.0f; |
| SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( |
| *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, |
| src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, |
| dpct::get_value(&beta, *stream), dst_dd_i, ldc))); |
| } |
| } |
| GGML_UNUSED(dst); |
| GGML_UNUSED(src1_ddq_i); |
| GGML_UNUSED(src1_padded_row_size); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| const int32_t * opts = (const int32_t *)dst->op_params; |
| enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]); |
| const int k0 = opts[1]; |
| const int k1 = opts[2]; |
| const int s0 = opts[3]; |
| const int s1 = opts[4]; |
| const int p0 = opts[5]; |
| const int p1 = opts[6]; |
|
|
| const int64_t IH = dst->src[0]->ne[1]; |
| const int64_t IW = dst->src[0]->ne[0]; |
|
|
| const int64_t N = dst->ne[3]; |
| const int64_t OC = dst->ne[2]; |
| const int64_t OH = dst->ne[1]; |
| const int64_t OW = dst->ne[0]; |
|
|
| const int parallel_elements = N * OC * OH * OW; |
| const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE; |
| sycl::range<3> block_nums(1, 1, num_blocks); |
| main_stream->parallel_for( |
| sycl::nd_range<3>(block_nums * |
| sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE), |
| sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)), |
| [=](sycl::nd_item<3> item_ct1) { |
| pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, |
| parallel_elements, src0_dd, dst_dd, op, |
| item_ct1); |
| }); |
| } |
|
|
| inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| const int64_t ne = ggml_nelements(dst->src[0]); |
|
|
| sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); |
| } |
|
|
| inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| const int64_t ncols = dst->src[0]->ne[0]; |
| const int64_t nrows = ggml_nrows(dst->src[0]); |
|
|
| sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); |
| } |
|
|
| inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT(dst->type == GGML_TYPE_F32); |
|
|
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
|
|
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| const int64_t ncols = dst->src[0]->ne[0]; |
| const int64_t nrows = ggml_nrows(dst->src[0]); |
|
|
| sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); |
|
|
| main_stream->parallel_for( |
| sycl::range<1>(nrows), |
| [=](sycl::id<1> row) { |
| dst_dd[row] /= ncols; |
| } |
| ); |
| } |
|
|
|
|
| inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT(dst->type == GGML_TYPE_I32); |
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| int32_t * dst_dd = static_cast<int32_t *>(dst->data); |
|
|
|
|
| const int64_t ncols = dst->src[0]->ne[0]; |
| const int64_t nrows = ggml_nrows(dst->src[0]); |
|
|
| enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; |
|
|
| argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, |
| main_stream, ctx.device); |
| } |
|
|
| static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| const ggml_tensor * src0 = dst->src[0]; |
|
|
| GGML_ASSERT(src0); |
| GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| GGML_ASSERT(dst->type == GGML_TYPE_I32); |
| GGML_ASSERT(ggml_is_contiguous(src0)); |
|
|
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
|
|
| const float * src0_dd = static_cast<const float *>(src0->data); |
| int32_t * dst_dd = static_cast<int32_t *>(dst->data); |
|
|
| const int k = dst->ne[0]; |
| const int64_t ncols = src0->ne[0]; |
| const int64_t nrows = ggml_nrows(src0); |
|
|
| GGML_ASSERT(k > 0 && k <= 32); |
| GGML_ASSERT(k <= ncols); |
|
|
| top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream); |
| } |
|
|
| inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT( dst->type == GGML_TYPE_I32); |
|
|
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| int32_t * dst_dd = static_cast<int32_t *>(dst->data); |
|
|
| const int64_t ncols = dst->src[0]->ne[0]; |
| const int64_t nrows = ggml_nrows(dst->src[0]); |
|
|
| argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); |
| } |
|
|
| inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| const int64_t ne00 = dst->src[0]->ne[0]; |
| const int64_t ne01 = dst->src[0]->ne[1]; |
| const int nrows0 = ggml_nrows(dst->src[0]); |
|
|
| const int n_past = ((int32_t *) dst->op_params)[0]; |
|
|
| diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); |
| } |
|
|
| static void tri_f32_sycl( |
| const float * src, |
| float * dst, |
| const int64_t ne0, |
| const int64_t ne1, |
| const int64_t ne2, |
| const int64_t ne3, |
| const ggml_tri_type ttype, |
| dpct::queue_ptr main_stream |
| ) { |
| const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3; |
|
|
| main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) { |
| const int64_t idx = (int64_t) tid[0]; |
|
|
| const int64_t i0 = idx % ne0; |
| const int64_t t1 = idx / ne0; |
| const int64_t i1 = t1 % ne1; |
|
|
| bool keep = false; |
| switch (ttype) { |
| case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break; |
| case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break; |
| case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break; |
| case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break; |
| default: keep = false; break; |
| } |
|
|
| dst[idx] = keep ? src[idx] : 0.0f; |
| }); |
| } |
|
|
| static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| const ggml_tensor * src0 = dst->src[0]; |
| GGML_ASSERT(src0); |
|
|
| GGML_ASSERT(src0->type == GGML_TYPE_F32); |
| GGML_ASSERT(dst->type == GGML_TYPE_F32); |
| GGML_ASSERT(ggml_is_contiguous(src0)); |
| GGML_ASSERT(ggml_is_contiguous(dst)); |
| GGML_ASSERT(ggml_are_same_shape(src0, dst)); |
|
|
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
|
|
| const float * src0_dd = static_cast<const float *>(src0->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); |
|
|
| const int64_t ne0 = src0->ne[0]; |
| const int64_t ne1 = src0->ne[1]; |
| const int64_t ne2 = src0->ne[2]; |
| const int64_t ne3 = src0->ne[3]; |
|
|
| tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream); |
| } |
|
|
|
|
| inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |
| GGML_ASSERT( dst->type == GGML_TYPE_F32); |
| dpct::queue_ptr main_stream = ctx.stream(); |
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |
| float * dst_dd = static_cast<float *>(dst->data); |
|
|
| float scale; |
| float bias; |
| memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); |
| memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); |
|
|
| scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream); |
| |
| |
| |
| |
| SYCL_CHECK(0); |
| } |
|
|
| static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { |
| static bool peer_access_enabled = false; |
|
|
| const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE; |
|
|
| if (peer_access_enabled == enable_peer_access) { |
| return; |
| } |
|
|
| #ifdef NDEBUG |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| SYCL_CHECK(ggml_sycl_set_device(i)); |
| } |
|
|
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| SYCL_CHECK(ggml_sycl_set_device(i)); |
|
|
| for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) { |
| if (i == id_other) { |
| continue; |
| } |
| if (i != main_device && id_other != main_device) { |
| continue; |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| } |
| } |
| #endif |
|
|
| peer_access_enabled = enable_peer_access; |
| } |
|
|
// Generic multi-device driver for matrix multiplication dst = src0 x src1.
//
// The actual per-slice computation is delegated to `op`, which is invoked once
// per (device, src1 column chunk, i0) where i0 runs over ne13*ne12 slices.
// This function prepares per-device buffers, optionally quantizes src1 to
// Q8_1, copies inputs/outputs between devices and orders the streams with
// barrier events.
//
// Template parameter:
//   quantize_f  functor template used to quantize contiguous src1 to Q8_1;
//               quantization is compiled out when it is no_quantize_q8_1.
// Parameters:
//   ctx         backend context (main device id, streams, memory pools)
//   src0        left operand; may live in a split buffer, in which case its
//               rows are partitioned across devices
//   src1, dst   right operand and result; must not live in split buffers
//   op          callback performing the multiplication for one slice
//
// A sycl::exception that escapes is reported on stderr and terminates the process.
| template <template <int> typename quantize_f> |
| static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, |
| const ggml_tensor *src1, ggml_tensor *dst, |
| ggml_sycl_op_mul_mat_t op) try { |
|
|
| GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); |
|
|
| GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); |
| const int64_t nrows1 = ggml_nrows(src1); |
|
|
| GGML_ASSERT(ne03 == ne13); |
|
|
| const int64_t ne0 = dst->ne[0]; |
| const int64_t ne1 = dst->ne[1]; |
|
|
// NOTE(review): the dst byte strides are narrowed to int here - confirm they
// cannot exceed INT_MAX for the tensors this path handles.
| const int nb2 = dst->nb[2]; |
| const int nb3 = dst->nb[3]; |
|
|
| GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer)); |
| GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer)); |
| GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); |
|
|
| GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); |
|
|
// Number of consecutive src1 slices along dim 2 that map to one src0 slice.
| const int64_t i02_divisor = ne12 / ne02; |
|
|
| const size_t src0_ts = ggml_type_size(src0->type); |
| const size_t src0_bs = ggml_blck_size(src0->type); |
| const size_t q8_1_ts = sizeof(block_q8_1); |
| const size_t q8_1_bs = QK8_1; |
|
|
| ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; |
| ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; |
|
|
| const bool src0_is_contiguous = ggml_is_contiguous(src0); |
| const bool src1_is_contiguous = ggml_is_contiguous(src1); |
|
|
// Row length of src1 rounded up to a multiple of MATRIX_ROW_PADDING.
| int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); |
|
|
// A split src0 is only supported when it has a single slice in dims 2 and 3.
| const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); |
| GGML_ASSERT(!(split && ne02 > 1)); |
| GGML_ASSERT(!(split && ne03 > 1)); |
| GGML_ASSERT(!(split && ne02 < ne12)); |
|
|
// tensor_split is only assigned (and only read) when src0 is split.
| std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split; |
| if (split) { |
| 
| 
| ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; |
| tensor_split = buft_ctx->tensor_split; |
| } |
|
|
// Per-device working state: pool-owned allocations, the raw pointers actually
// used (which may alias tensor data instead of the allocations), and the
// half-open src0 row range [row_low, row_high) assigned to the device.
| struct dev_data { |
| ggml_sycl_pool_alloc<char> src0_dd_alloc; |
| ggml_sycl_pool_alloc<float> src1_ddf_alloc; |
| ggml_sycl_pool_alloc<char> src1_ddq_alloc; |
| ggml_sycl_pool_alloc<float> dst_dd_alloc; |
|
|
| char *src0_dd = nullptr; |
| float *src1_ddf = nullptr; |
| char *src1_ddq = nullptr; |
| float *dst_dd = nullptr; |
|
|
| int64_t row_low; |
| int64_t row_high; |
| }; |
|
|
| dev_data dev[GGML_SYCL_MAX_DEVICES]; |
|
|
| int used_devices = 0; |
| queue_ptr main_stream = ctx.stream(); |
|
|
// Pass 1: assign each device its row range. Without splitting every device
// gets all rows; with splitting the boundaries come from the tensor_split
// fractions, rounded down to a multiple of `rounding` unless they reach ne01.
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| 
| dev[i].row_low = 0; |
| dev[i].row_high = ne01; |
|
|
| 
| 
| if (split) { |
| const int64_t rounding = get_row_rounding(src0->type, tensor_split); |
|
|
| if (i != 0) { |
| dev[i].row_low = ne01*tensor_split[i]; |
| if (dev[i].row_low < ne01) { |
| dev[i].row_low -= dev[i].row_low % rounding; |
| } |
| } |
|
|
| if (i != ggml_sycl_info().device_count - 1) { |
| dev[i].row_high = ne01*tensor_split[i + 1]; |
| if (dev[i].row_high < ne01) { |
| dev[i].row_high -= dev[i].row_high % rounding; |
| } |
| } |
| } |
| } |
|
|
// Pass 2: set up buffers on every participating device. A device is skipped
// when src0 is not split and it is not the main device, or when its row range
// is empty.
| constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>, |
| no_quantize_q8_1<QK8_1 / WARP_SIZE>>; |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { |
| continue; |
| } |
|
|
| used_devices++; |
|
|
// src1 and dst are treated as resident only on the main device.
| const bool src1_on_device = i == ctx.device; |
| const bool dst_on_device = i == ctx.device; |
|
|
// NOTE(review): the return value of ggml_sycl_set_device is not checked here,
// unlike the SYCL_CHECK-wrapped calls elsewhere in this file.
| ggml_sycl_set_device(i); |
| queue_ptr stream = ctx.stream(i, 0); |
|
|
// Use tensor data in place when possible, otherwise allocate staging buffers.
| if (src0_is_contiguous) { |
| dev[i].src0_dd = (char *) src0->data; |
| } else { |
| dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0)); |
| } |
|
|
| if (src1_on_device && src1_is_contiguous) { |
| dev[i].src1_ddf = (float *) src1->data; |
| } else { |
| dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1)); |
| } |
|
|
// When quantization is enabled, allocate the Q8_1 buffer for all of src1 and,
// if src1 is contiguous on this device, quantize it right away.
| if constexpr(quantize_enabled) { |
| dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); |
|
|
| if (src1_on_device && src1_is_contiguous) { |
| scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, |
| 2, " : converting src1 to Q8_1"); |
| try { |
| quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); |
| } catch (sycl::exception const &exc) { |
| std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
| } |
| } |
|
|
| if (dst_on_device) { |
| dev[i].dst_dd = (float *) dst->data; |
| } else { |
| const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst); |
| dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf); |
| } |
| } |
|
|
| 
| 
// With several devices in use, record a barrier event on the main device's
// stream; the other streams wait on it below before doing any work.
| if (split && used_devices > 1) { |
| ggml_sycl_set_device(ctx.device); |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| *src0_extra->events[ctx.device][0] = |
| ctx.stream()->ext_oneapi_submit_barrier())); |
| } |
|
|
// Main loop: src1 is processed in chunks of src1_col_stride columns (the whole
// matrix at once unless several devices are used). `is` selects the stream for
// the chunk and cycles through GGML_SYCL_MAX_STREAMS when src0 is split.
| const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; |
| for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { |
| const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0; |
| const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { |
| continue; |
| } |
|
|
| const bool src1_on_device = i == ctx.device; |
| const bool dst_on_device = i == ctx.device; |
| const int64_t row_diff = dev[i].row_high - dev[i].row_low; |
|
|
| ggml_sycl_set_device(i); |
| queue_ptr stream = ctx.stream(i, is); |
|
|
| 
// Every stream other than the main device's stream 0 waits for the event
// recorded on the main stream above.
| if (split && (i != ctx.device || is != 0)) { |
| SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier( |
| {*src0_extra->events[ctx.device][0]}))); |
| } |
|
|
// i0 enumerates the (i03, i02) slices of src1/dst.
| for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { |
| const int64_t i03 = i0 / ne12; |
| const int64_t i02 = i0 % ne12; |
|
|
| const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; |
|
|
| 
// Pointers to the current slice/chunk inside each per-device buffer.
| char * src0_dd_i = dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; |
| float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10; |
| char * src1_ddq_i = dev[i].src1_ddq + src1_ddq_i_offset; |
| float * dst_dd_i = dev[i].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); |
|
|
| 
| 
// On the main device dst_dd aliases dst->data, so advance to this device's
// first row inside the full-width result.
| if (i == ctx.device) { |
| dst_dd_i += dev[i].row_low; |
| } |
|
|
| 
// Bring the current src1 chunk onto this device. Contiguous src1: copy from the
// main device (quantized or float form) when this is another device.
// Non-contiguous src1: only supported on the main device, where it is packed
// via ggml_sycl_cpy_tensor_2d and then quantized if enabled.
| if (src1_is_contiguous) { |
| if (i != ctx.device) { |
| if constexpr (quantize_enabled) { |
| char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; |
| SYCL_CHECK( |
| CHECK_TRY_ERROR(stream |
| ->memcpy(src1_ddq_i, src1_ddq_i_source, |
| src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs) |
| .wait())); |
| } else { |
| float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device]; |
| src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10; |
|
|
| SYCL_CHECK( |
| CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source, |
| src1_ncols * ne10 * sizeof(float)))); |
| } |
| } |
| } else { |
| if (src1_on_device) { |
| SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0, |
| src1_col_0 + src1_ncols, stream)); |
| } else { |
| GGML_ABORT("src1 is non-contiguous and not on device"); |
| } |
|
|
// NOTE(review): this path instantiates quantize_row_q8_1_sycl with quantize_q8_1
// rather than the quantize_f template parameter used above - confirm intended.
| if constexpr (quantize_enabled) { |
| scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, |
| 2, " : converting src1 to Q8_1"); |
| try { |
| quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, |
| src1_padded_col_size, stream); |
| } catch (const sycl::exception & exc) { |
| std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() |
| << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
| } |
| } |
|
|
// Non-contiguous src0 is packed into the staging buffer once per src0 slice
// (first column chunk, first src1 slice mapping to it).
| if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) { |
| SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream)); |
| } |
// For F16 src1 the variable is repurposed to carry an element offset to `op`.
// NOTE(review): this overwrites src1_padded_col_size for all later iterations,
// where it also feeds src1_ddq_i_offset - confirm this is intended.
| if (src1->type == GGML_TYPE_F16) { |
| src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10; |
| } |
| 
// Run the actual multiplication for this slice/chunk.
| SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, |
| dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream))); |
|
|
| 
// Results computed on a non-main device are copied into dst->data.
| if (!dst_on_device) { |
| void * dst_off_device = dst->data; |
| if (split) { |
| 
| 
| 
| 
| 
// Split case: pitched 2-D copy placing row_diff floats per column at row
// offset row_low inside dst rows of ne0 floats.
| float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); |
| GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); |
| dhf_dst_i += src1_col_0*ne0 + dev[i].row_low; |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy( |
| dhf_dst_i, ne0 * sizeof(float), dst_dd_i, |
| row_diff * sizeof(float), row_diff * sizeof(float), |
| src1_ncols, dpct::device_to_device, *stream))); |
| } else { |
// Non-split case: one linear, blocking copy of the whole chunk.
| float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); |
| GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); |
| dhf_dst_i += src1_col_0*ne0; |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| stream->memcpy(dhf_dst_i, dst_dd_i, |
| src1_ncols * ne0 * sizeof(float)).wait())); |
| } |
| } |
|
|
| 
// Record completion of this stream's work so the main stream can wait on it.
| if (split && (i != ctx.device || is != 0)) { |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| *src0_extra->events[i][is] = |
| stream->ext_oneapi_submit_barrier())); |
| } |
| } |
| } |
| } |
|
|
| 
// Finally make the main device's stream wait on the events of every device with
// a non-empty row range, for each stream index that was used.
| if (split && ggml_sycl_info().device_count > 1) { |
| int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; |
| is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS; |
|
|
| ggml_sycl_set_device(ctx.device); |
| for (int i = 0; i < ggml_sycl_info().device_count; ++i) { |
| if (dev[i].row_low == dev[i].row_high) { |
| continue; |
| } |
| for (int64_t is = 0; is < is_max; ++is) { |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| ctx.stream()->ext_oneapi_submit_barrier( |
| {*src0_extra->events[i][is]}))); |
| } |
| } |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_repeat_back(ctx, dst); |
| } |
|
|
| static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 2); |
| ggml_sycl_op_get_rows(ctx, dst); |
| } |
|
|
| static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_norm(ctx, dst); |
| } |
|
|
| static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_rms_norm(ctx, dst); |
| } |
|
|
| static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 2); |
| ggml_sycl_op_rms_norm_back(ctx, dst); |
| } |
|
|
| static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_l2_norm(ctx, dst); |
| } |
|
|
| static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_group_norm(ctx, dst); |
| } |
|
|
| static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, |
| const ggml_tensor *src1, |
| ggml_tensor *dst) try { |
| GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); |
| GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); |
| GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); |
| GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); |
| GGML_ASSERT(src0->type == GGML_TYPE_F16); |
| GGML_ASSERT(src1->type == GGML_TYPE_F32); |
|
|
| const int64_t ne00 = src0->ne[0]; |
| const int64_t ne01 = src0->ne[1]; |
| const int64_t ne02 = src0->ne[2]; |
|
|
| const int64_t ne12 = src1->ne[2]; |
|
|
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| queue_ptr main_stream = ctx.stream(); |
|
|
| void * src0_ddq = src0->data; |
| float * src1_ddf = (float *) src1->data; |
| float * dst_ddf = (float *) dst->data; |
|
|
| ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, |
| const ggml_tensor *src1, |
| ggml_tensor *dst) try { |
| GGML_ASSERT(!ggml_is_transposed(src0)); |
| GGML_ASSERT(!ggml_is_transposed(src1)); |
| GGML_ASSERT(!ggml_is_permuted(src0)); |
| GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); |
| GGML_ASSERT(src0->type == GGML_TYPE_F16); |
| GGML_ASSERT(src1->type == GGML_TYPE_F32); |
| GGML_ASSERT(src1->ne[1] == 1); |
| GGML_ASSERT(src1->ne[3] == 1); |
|
|
| const int64_t ne00 = src0->ne[0]; |
| const int64_t ne01 = src0->ne[1]; |
| const int64_t ne02 = src0->ne[2]; |
|
|
| const int64_t nb01 = src0->nb[1]; |
| const int64_t nb02 = src0->nb[2]; |
|
|
| const int64_t ne12 = src1->ne[2]; |
| const int64_t nb11 = src1->nb[1]; |
|
|
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| queue_ptr main_stream = ctx.stream(); |
|
|
| void * src0_ddq = src0->data; |
| float * src1_ddf = (float *) src1->data; |
| float * dst_ddf = (float *) dst->data; |
|
|
| const int64_t row_stride_x = nb01 / sizeof(sycl::half); |
| const int64_t channel_stride_x = nb02 / sizeof(sycl::half); |
| const int64_t channel_stride_y = nb11 / sizeof(float); |
|
|
| ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst, |
| const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, |
| size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, |
| int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) { |
| const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); |
| const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); |
|
|
| if (i13 >= ne13 || i12 >= ne12) { |
| return; |
| } |
|
|
| const int64_t i03 = i13 / r3; |
| const int64_t i02 = i12 / r2; |
|
|
| const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16); |
| const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16); |
| uint8_t * dst_bytes = static_cast<uint8_t *>(dst); |
|
|
| ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03; |
| ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13; |
| ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3; |
| } |
|
|
// Batched matrix multiplication with an F16 src0 and an F32 dst.
//
// Steps visible in this function:
//   1. Assert the layout preconditions (no transposed operands, no split buffer, F16 src0,
//      F32 contiguous dst).
//   2. If src1 is not F16, convert it into a pool-allocated F16 buffer.
//   3. When GGML_SYCL_DNNL is compiled in and g_ggml_sycl_disable_dnn is not set, run the
//      product through DnnlGemmWrapper::gemm, splitting into several calls when the batch
//      counts of the two operands cannot be broadcast directly.
//   4. Otherwise use dpct::gemm_batch: the strided variant when no broadcast is needed and
//      both operands are contiguous in dims 2+, else the pointer-array variant whose tables
//      are filled by k_compute_batched_ptrs.
//
// A sycl::exception is printed to stderr and terminates the process with exit code 1.
| static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, |
| const ggml_tensor * src1, ggml_tensor * dst) try { |
| GGML_ASSERT(!ggml_is_transposed(src0)); |
| GGML_ASSERT(!ggml_is_transposed(src1)); |
| GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); |
| GGML_ASSERT(src0->type == GGML_TYPE_F16); |
| GGML_ASSERT(dst->type == GGML_TYPE_F32); |
|
|
| GGML_TENSOR_BINARY_OP_LOCALS |
|
|
| // dst is addressed below with ne0 / ne1 * ne0 and nb2-derived strides, so it has to be |
| // contiguous for every code path in this function. |
| GGML_ASSERT(ggml_is_contiguous(dst)); |
|
|
| SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |
| queue_ptr queue = ctx.stream(); |
|
|
| dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); |
|
|
| const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data); |
| float * dst_ddf = static_cast<float *>(dst->data); |
|
|
| const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data); |
| const size_t type_size_src0 = ggml_type_size(src0->type); |
| const size_t type_size_src1 = ggml_type_size(src1->type); |
|
|
| bool is_src0_cont_2 = ggml_is_contiguous_2(src0); |
| bool is_src1_cont_2 = ggml_is_contiguous_2(src1); |
|
|
| // src1 strides in elements (byte strides divided by the src1 element size). |
| int64_t s11 = nb11 / type_size_src1; |
| int64_t s12 = nb12 / type_size_src1; |
| int64_t s13 = nb13 / type_size_src1; |
| ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool()); |
|
|
| // src1 is not F16 yet: convert it into src1_f16_alloc and point src1_f16 at the copy. |
| if (src1->type != GGML_TYPE_F16) { |
| scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, 2, |
| " : converting src1 to fp16"); |
|
|
| // Pick the dimension with the largest byte stride; the DNNL branch uses it to size the conversion. |
| int last_dim=0; |
| int last_str=0; |
| size_t largest_str=0; |
| for(int i = 0; i< 4; i++){ |
| // On a stride tie, move on to dimension i only if the currently selected dimension has extent 1. |
| if(src1->nb[i] == largest_str){ |
| if(src1->ne[last_dim] == 1){ |
| last_str = i; |
| last_dim = i; |
| } |
| } |
| if(src1->nb[i] > largest_str){ |
| largest_str = src1->nb[i]; |
| last_str = i; |
| last_dim = i; |
| } |
|
|
| } |
| #if GGML_SYCL_DNNL |
| // Converts largest_stride * extent elements linearly (the whole strided span, not just ggml_nelements). |
| const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1; |
| src1_f16_alloc.alloc(ne_src1); |
| const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); |
| GGML_ASSERT(to_fp16_sycl != nullptr); |
| to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue); |
| # else |
| const int64_t ne_src1 = ggml_nelements(src1); |
| src1_f16_alloc.alloc(ne_src1); |
| const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type); |
| GGML_ASSERT(to_fp16_nc_sycl != nullptr); |
| to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); |
| #endif |
|
|
| src1_f16 = src1_f16_alloc.get(); |
| s11 = ne10; |
| s12 = ne11 * s11; |
| s13 = ne12 * s12; |
|
|
| is_src1_cont_2 = true; |
| } |
|
|
| ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool()); |
|
|
| dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float; |
| dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float; |
|
|
| // dst byte strides for dims 2 and 3, passed to k_compute_batched_ptrs below. |
| size_t nbd2 = dst->nb[2]; |
| size_t nbd3 = dst->nb[3]; |
|
|
| const float alpha_f32 = 1.0f; |
| const float beta_f32 = 0.0f; |
|
|
| const void * alpha = &alpha_f32; |
| const void * beta = &beta_f32; |
|
|
| GGML_ASSERT(ne12 % ne02 == 0); |
| GGML_ASSERT(ne13 % ne03 == 0); |
| GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0)); |
| GGML_ASSERT(ne10 == ne00); |
|
|
| // Broadcast factors: number of src1 batches per src0 batch in dims 2 and 3. |
| const int64_t r2 = ne12 / ne02; |
| const int64_t r3 = ne13 / ne03; |
|
|
| #if GGML_SYCL_DNNL |
| if (!g_ggml_sycl_disable_dnn) { |
| int64_t str_a0 = nb00 / type_size_src0; |
| int64_t str_a1 = nb01 / type_size_src0; |
| int64_t str_a2 = nb02 / type_size_src0; |
|
|
| int64_t str_b0 = nb10 / type_size_src1; |
| int64_t str_b1 = nb11 / type_size_src1; |
| int64_t str_b2 = nb12 / type_size_src1; |
|
|
| auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0, |
| const sycl::half *src1, float *dst, |
| int64_t a0, int64_t a1, int64_t batcha, |
| int64_t , int64_t b1, int64_t batchb, |
| int64_t sa0, int64_t sa1, int64_t sa2, |
| int64_t sb0, int64_t sb1, int64_t sb2, |
| int64_t sd2) { |
| bool supported_broadcast = batchb == batcha ? true |
| : batchb == 1 || batcha == 1 ? true |
| : false; |
| if (supported_broadcast) { |
| DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0, |
| DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1, |
| DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst, |
| DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb); |
| } else { |
| // Batch counts differ and neither is 1: issue one GEMM per batch of the smaller operand. |
| int64_t batches0 = batcha; |
| int64_t batches1 = batchb; |
|
|
| if (batches0 > batches1) { |
| int64_t num_mul_mats = batches1; |
| int64_t sub_batch = batches0 / num_mul_mats; |
| // Each call multiplies sub_batch src0 batches with a single src1 batch. |
| for (int64_t i0 = 0; i0 < num_mul_mats; i0++) { |
| const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch); |
| const sycl::half *src1_shifted = src1 + (sb2 * i0); |
| float *dst_shifted = dst + (sd2 * i0 * sub_batch); |
| DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted, |
| DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, |
| src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0, |
| sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(), |
| queue, sub_batch, 1); |
| } |
| } else { |
| int64_t num_mul_mats = batches0; |
| int64_t sub_batch = batches1 / num_mul_mats; |
| // Each call multiplies a single src0 batch with sub_batch src1 batches. |
| for (int64_t i1 = 0; i1 < num_mul_mats; i1++) { |
| const sycl::half *src0_shifted = src0 + (sa2 * i1); |
| const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch); |
| float *dst_shifted = dst + (sd2 * i1 * sub_batch); |
| DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted, |
| DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, |
| src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0, |
| sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(), |
| queue, 1, sub_batch); |
| } |
| } |
| } |
| }; |
|
|
| const bool cont_batches_dim2_a = nb02 * ne02 == nb03; |
| const bool cont_batches_dim2_b = nb12 * ne12 == nb13; |
| const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03; |
| const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13; |
| if (cont_batches_dim2_a && cont_batches_dim2_b) { |
| // dim 3 follows dim 2 without a gap in both operands: treat dims 2 and 3 as one batch dimension. |
| int64_t batches0 = ne02 * ne03; |
| int64_t batches1 = ne12 * ne13; |
| launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0, |
| ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1, |
| str_b2, nb2 / sizeof(float)); |
| } else if (cont_batches_dim3_a && cont_batches_dim3_b) { |
| // dim 2 has extent 1 in both operands: batch over dim 3 using the dim-3 strides. |
| int64_t batches0 = ne02 * ne03; |
| int64_t batches1 = ne12 * ne13; |
| int64_t str_a3 = nb03 / type_size_src0; |
| int64_t str_b3 = nb13 / type_size_src1; |
| launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0, |
| ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1, |
| str_b3, nb2 / sizeof(float)); |
| } else { |
| for (int64_t b_a = 0; b_a < ne03; b_a++) { |
| const sycl::half *src0_f16_shifted |
| = src0_f16 + (nb03 * b_a / type_size_src0); |
| const sycl::half *src1_f16_shifted |
| = src1_f16 + (nb13 * b_a / type_size_src1); |
| float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float)); |
| int64_t batches0 = ne02; |
| int64_t batches1 = ne12; |
| launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted, |
| ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1, |
| str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float)); |
| } |
| } |
|
|
| } |
| else |
| #endif |
| { |
| if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { |
| // Batch strides in elements; when dim 2 has extent 1 the dim-3 stride is used instead. |
| const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; |
| const int64_t smb = ne12 == 1 ? s13 : s12; |
|
|
| // No broadcast and both operands contiguous in dims 2+: one strided batched GEMM over ne12 * ne13 batches. |
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans, |
| oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, |
| src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma, |
| src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf, |
| mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); |
| } else { |
| const int ne23 = ne12 * ne13; |
|
|
| ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23); |
| ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23); |
| ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1); |
|
|
| sycl::range<3> block_dims(1, ne12, ne13); |
| queue->submit([&](sycl::handler & cgh) { |
| const void ** ptrs_src_get = ptrs_src.get(); |
| void ** ptrs_dst_get = ptrs_dst.get(); |
| size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); |
| size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); |
| cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { |
| k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, |
| nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); |
| }); |
| }); |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( |
| *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, |
| (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, |
| (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, |
| (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); |
| } |
| } |
| } catch (const sycl::exception & exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
// Matrix-multiplication code path a weight reorder is being prepared for
// (consumed by opt_for_reorder to pick the matching "supports reorder" predicate).
enum class mul_mat_algo {
    DMMV         = 0,  // checked with ggml_sycl_supports_reorder_dmmv
    MMVQ         = 1,  // checked with ggml_sycl_supports_reorder_mmvq
    MUL_MAT_SYCL = 2,  // checked with ggml_sycl_supports_reorder_mul_mat_sycl
};
|
|
| inline bool ggml_sycl_supports_mmq(enum ggml_type type) { |
| |
| GGML_UNUSED(type); |
| return false; |
| } |
|
|
| inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { |
| switch (type) { |
| case GGML_TYPE_Q4_0: |
| return true; |
| case GGML_TYPE_Q4_K: |
| case GGML_TYPE_Q6_K: |
| return !g_ggml_sycl_prioritize_dmmv; |
| default: |
| return false; |
| } |
| } |
|
|
| inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { |
| switch (type) { |
| case GGML_TYPE_Q4_0: |
| return true; |
| default: |
| return false; |
| } |
| } |
|
|
| inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { |
| switch (type) { |
| case GGML_TYPE_Q4_0: |
| case GGML_TYPE_Q4_K: |
| case GGML_TYPE_Q6_K: |
| return true; |
| default: |
| return false; |
| } |
| } |
|
|
| static bool ggml_sycl_supports_dmmv(enum ggml_type type) { |
| switch (type) { |
| case GGML_TYPE_Q4_0: |
| case GGML_TYPE_Q4_1: |
| case GGML_TYPE_Q5_0: |
| case GGML_TYPE_Q5_1: |
| case GGML_TYPE_Q8_0: |
| case GGML_TYPE_Q2_K: |
| case GGML_TYPE_Q3_K: |
| case GGML_TYPE_Q4_K: |
| case GGML_TYPE_Q5_K: |
| case GGML_TYPE_Q6_K: |
| case GGML_TYPE_F16: |
| return true; |
| default: |
| return false; |
| } |
| } |
|
|
| |
| static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) { |
| bool use_async = g_ggml_sycl_use_async_mem_op; |
| #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC |
| if (use_async) { |
| return syclex::async_malloc(*stream, sycl::usm::alloc::device, size); |
| } |
| #else |
| |
| GGML_ASSERT(!use_async); |
| #endif |
| return sycl::malloc(size, *stream, sycl::usm::alloc::device); |
| } |
|
|
| static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { |
| bool use_async = g_ggml_sycl_use_async_mem_op; |
| #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC |
| if (use_async) { |
| syclex::async_free(*stream, ptr); |
| return; |
| } |
| #else |
| |
| GGML_ASSERT(!use_async); |
| #endif |
| sycl::free(ptr, *stream); |
| } |
|
|
| static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, |
| dpct::queue_ptr stream) { |
| uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size)); |
|
|
| sycl::event copy_event; |
| SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); |
| if (!g_ggml_sycl_use_async_mem_op) { |
| copy_event.wait(); |
| } |
|
|
| GGML_ASSERT((size % sizeof(block_q4_0) == 0)); |
| GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); |
| int offset_blks = offset / sizeof(block_q4_0); |
| auto qs_ptr = data_device + offset_blks * QK4_0 / 2; |
| auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; |
|
|
| auto reorder_event = stream->parallel_for( |
| size / sizeof(block_q4_0), |
| [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |
| const block_q4_0* x = (const block_q4_0*)tmp_buf; |
| const int ib = i; |
|
|
| for (int j = 0; j < QK4_0/2; j ++) |
| { |
| *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; |
| } |
| *(d_ptr + ib) = x[ib].d; |
| }); |
| if (!g_ggml_sycl_use_async_mem_op) { |
| reorder_event.wait_and_throw(); |
| } |
| sycl_ext_free(stream, tmp_buf); |
| } |
|
|
| static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { |
| GGML_ASSERT(size % sizeof(block_q4_K) == 0); |
| GGML_ASSERT(offset % sizeof(block_q4_K) == 0); |
|
|
| const int nblocks = size / sizeof(block_q4_K); |
|
|
| uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size)); |
|
|
| sycl::event copy_event; |
| SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); |
| if (!g_ggml_sycl_use_async_mem_op) { |
| copy_event.wait(); |
| } |
|
|
| auto * qs_ptr = data_device; |
| auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; |
| auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); |
|
|
| auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { |
| const block_q4_K * x = (const block_q4_K *) tmp_buf; |
| const int ib = i; |
|
|
| for (int j = 0; j < QK_K / 2; ++j) { |
| qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; |
| } |
|
|
| for (int j = 0; j < K_SCALE_SIZE; ++j) { |
| scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; |
| } |
|
|
| dm_ptr[ib] = x[ib].dm; |
| }); |
| if (!g_ggml_sycl_use_async_mem_op) { |
| reorder_event.wait_and_throw(); |
| } |
| sycl_ext_free(stream, tmp_buf); |
| } |
|
|
| static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { |
| GGML_ASSERT(size % sizeof(block_q6_K) == 0); |
| GGML_ASSERT(offset % sizeof(block_q6_K) == 0); |
|
|
| const int nblocks = size / sizeof(block_q6_K); |
|
|
| uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size)); |
|
|
| sycl::event copy_event; |
| SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); |
| if (!g_ggml_sycl_use_async_mem_op) { |
| copy_event.wait(); |
| } |
|
|
| auto * ql_ptr = data_device; |
| auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks; |
| auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks; |
| sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks); |
|
|
| auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { |
| const block_q6_K * x = (const block_q6_K *) tmp_buf; |
| const int ib = i; |
|
|
| const uint8_t * ql = x[ib].ql; |
| const uint8_t * qh = x[ib].qh; |
| uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib; |
| uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib; |
| uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib; |
|
|
| for (int j = 0; j < QK_K / 2; ++j) { |
| base_ql_ptr[j] = ql[j]; |
| } |
| for (int j = 0; j < QK_K / 4; ++j) { |
| base_qh_ptr[j] = qh[j]; |
| } |
|
|
| for (int j = 0; j < QK_K / 16; ++j) { |
| base_scales_ptr[j] = x[ib].scales[j]; |
| } |
|
|
| dm_ptr[ib] = x[ib].d; |
| }); |
| if (!g_ggml_sycl_use_async_mem_op) { |
| reorder_event.wait_and_throw(); |
| } |
| sycl_ext_free(stream, tmp_buf); |
| } |
|
|
| static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { |
| uint8_t * data_device = (uint8_t *) src0->data; |
| size_t ncols = src0->ne[0]; |
| size_t nrows = src0->ne[1]; |
| size_t size = ggml_nbytes(src0); |
|
|
| switch (src0->type) { |
| case GGML_TYPE_Q4_0: |
| reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); |
| break; |
| case GGML_TYPE_Q4_K: |
| reorder_qw_q4_k(data_device, size, 0, stream); |
| break; |
| case GGML_TYPE_Q6_K: |
| reorder_qw_q6_k(data_device, size, 0, stream); |
| break; |
| default: |
| GGML_ABORT("reorder_qw() called with unsupported type"); |
| break; |
| } |
| } |
|
|
| static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { |
| return !g_ggml_sycl_disable_optimize && |
| ctx.opt_feature.reorder && |
| dst->op == GGML_OP_MUL_MAT && |
| dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; |
| } |
|
|
| static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * , |
| ggml_tensor * dst, mul_mat_algo mm_algorithm) { |
| if (!should_reorder_tensor(*ctx, dst)) { |
| return; |
| } |
|
|
| ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra); |
| if (!extra || extra->optimized_feature.reorder) { |
| return; |
| } |
|
|
| switch (mm_algorithm) { |
| case mul_mat_algo::DMMV: |
| if (!ggml_sycl_supports_reorder_dmmv(src0->type)) { |
| return; |
| } |
| break; |
| case mul_mat_algo::MMVQ: |
| if (!ggml_sycl_supports_reorder_mmvq(src0->type)) { |
| return; |
| } |
| break; |
| case mul_mat_algo::MUL_MAT_SYCL: |
| if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) { |
| return; |
| } |
| break; |
| } |
|
|
| reorder_qw(src0, ctx->stream()); |
| extra->optimized_feature.reorder = true; |
| } |
|
|
|
|
| static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && |
| src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; |
| } |
|
|
| static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && |
| src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; |
| } |
|
|
// Top-level matrix multiplication dispatcher: picks one of the specialised implementations
// for dst = src0 x src1 based on buffer kind (split or not), operand types and layouts.
//
// Selection order (first match wins):
//   1. F16 src0, both operands permuted, single src1 column -> p021 mat-vec kernel or batched GEMM
//   2. F16 non-contiguous src0, single src1 column          -> ggml_sycl_mul_mat_vec_nc
//   3. F16 src0 with more than one src1 batch               -> ggml_sycl_mul_mat_batched_sycl
//   4. DMMV eligible                                        -> dequantize_mul_mat_vec (after optional reorder)
//   5. MMVQ eligible                                        -> mul_mat_vec_q (after optional reorder)
//   6. MMQ eligible                                         -> mul_mat_q
//   7. everything else                                      -> ggml_sycl_op_mul_mat_sycl
// Cases 1-3 are only considered for non-split buffers.
| static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 2); |
| const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); |
| int64_t min_compute_capability = INT_MAX; |
|
|
| if (split) { |
| ggml_backend_sycl_split_buffer_type_context * buft_ctx = |
| (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; |
| auto & tensor_split = buft_ctx->tensor_split; |
| for (int id = 0; id < ggml_sycl_info().device_count; ++id) { |
| // skip devices whose split boundary is not below the next one (presumably an empty slice - confirm) |
| if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) { |
| continue; |
| } |
|
|
| if (min_compute_capability > ggml_sycl_info().devices[id].cc) { |
| min_compute_capability = ggml_sycl_info().devices[id].cc; |
| } |
| } |
| } else { |
| min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; |
| } |
|
|
| // Eligibility of each specialised kernel family for these operands. |
| bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst); |
|
|
| bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst); |
|
|
| bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) |
| && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; |
|
|
|
|
| // IQ2_XXS is excluded from the MMQ path (reason not visible here - TODO confirm); |
| // with SYCL_USE_XMX the MMQ path is additionally limited to MMQ_MAX_BATCH_SIZE src1 columns. |
| use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); |
| #ifdef SYCL_USE_XMX |
| use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); |
| #endif |
|
|
| // When the reorder optimisation applies to dst and MMVQ supports the reordered layout of |
| // src0's type, DMMV is dropped in favour of MMVQ whenever MMVQ is eligible - |
| // unless g_ggml_sycl_prioritize_dmmv is set. |
| if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && |
| ggml_sycl_supports_reorder_mmvq(src0->type)))) { |
| use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; |
| } |
|
|
| if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { |
| // F16 src0, both operands permuted, single src1 column |
| if (src0->ne[3] == 1 && src1->ne[3] == 1) { |
| // no extent in dim 3 on either operand: |
| // use the dedicated permuted mat-vec kernel |
| ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); |
| } else { |
| // dim 3 is in use: fall back to the batched GEMM |
| ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); |
| } |
| } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) { |
| // F16 src0 that is not contiguous, single src1 column: strided mat-vec kernel |
| ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); |
| } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) { |
| // F16 src0 with more than one src1 batch: batched GEMM |
| ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); |
| } else if (use_dequantize_mul_mat_vec) { |
| opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); |
| ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec); |
| } else if (use_mul_mat_vec_q) { |
| opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); |
| ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra); |
| if (extra && extra->optimized_feature.reorder) { |
| ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q); |
| } else { |
| ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q); |
| } |
| } else if (use_mul_mat_q) { |
| ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q); |
| } else { |
| ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl); |
| } |
| } |
|
|
|
|
// Records where a row gathered for mul_mat_id came from, so the result row can be written
// back to dst at (i1, i2). Filled as {id, iid1} by k_copy_src1_to_contiguous and read by
// k_copy_dst_from_contiguous. Must stay an aggregate with this member order.
struct mmid_row_mapping {
    int32_t i1;  // index applied with the dim-1 byte stride of dst
    int32_t i2;  // index applied with the dim-2 byte stride of dst
};
|
|
| __dpct_inline__ static void k_copy_src1_to_contiguous( |
| const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, |
| int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, |
| const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, |
| int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, |
| const sycl::nd_item<3> &item_ct1, int &src1_row) { |
| int32_t iid1 = item_ct1.get_group(2); |
| int32_t id = item_ct1.get_group(1); |
|
|
| const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); |
|
|
| if (row_id_i != i02) { |
| return; |
| } |
|
|
| const int64_t i11 = id % ne11; |
| const int64_t i12 = iid1; |
|
|
| if (item_ct1.get_local_id(2) == 0) { |
| src1_row = |
| dpct::atomic_fetch_add<sycl::access::address_space::generic_space>( |
| cur_src1_row, 1); |
| row_mapping[src1_row] = {id, iid1}; |
| } |
| |
| |
| |
| |
| |
| item_ct1.barrier(); |
|
|
| const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); |
| float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); |
|
|
| #pragma unroll |
| for (int i = item_ct1.get_local_id(2); i < ne10; |
| i += item_ct1.get_local_range(2)) { |
| src1_row_contiguous[i] = src1_row_original[i]; |
| } |
| } |
|
|
| __dpct_inline__ static void k_copy_dst_from_contiguous( |
| char *__restrict__ dst_original, const char *__restrict__ dst_contiguous, |
| const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1, |
| size_t nb2, const sycl::nd_item<3> &item_ct1) { |
| int32_t i = item_ct1.get_group(2); |
|
|
| const int32_t i1 = row_mapping[i].i1; |
| const int32_t i2 = row_mapping[i].i2; |
|
|
| const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); |
| float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); |
|
|
| #pragma unroll |
| for (int j = item_ct1.get_local_id(2); j < ne0; |
| j += item_ct1.get_local_range(2)) { |
| dst_row_original[j] = dst_row_contiguous[j]; |
| } |
| } |
|
|
// Indirect matrix multiplication: every int32 entry of `ids` (dst->src[2]) selects the slice
// of src0 (along dim 2) that is multiplied with the corresponding src1 row.
//
// The ids tensor is copied to the host (blocking on the stream) so the routing can be decided
// on the CPU. Two strategies follow:
//   * src1->ne[2] == 1: one ggml_sycl_mul_mat call per id entry on single-row views.
//   * otherwise: for each src0 slice, the matching src1 rows are gathered into a contiguous
//     buffer (k_copy_src1_to_contiguous), multiplied in one ggml_sycl_mul_mat call, and the
//     result rows are scattered back into dst (k_copy_dst_from_contiguous).
//
// Split buffers are not supported. A sycl::exception is printed to stderr and terminates the
// process with exit code 1.
| static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, |
| ggml_tensor *dst) try { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 3); |
| const ggml_tensor *src0 = dst->src[0]; |
| const ggml_tensor *src1 = dst->src[1]; |
| GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); |
|
|
| const ggml_tensor *ids = dst->src[2]; |
| GGML_TENSOR_BINARY_OP_LOCALS |
|
|
| const queue_ptr stream = ctx.stream(); |
|
|
| const int64_t n_as = ne02; |
| const int64_t n_ids = ids->ne[0]; |
|
|
// Host copy of the ids tensor; the stream is waited on so the data is valid below.
| std::vector<char> ids_host(ggml_nbytes(ids)); |
| const char * ids_dev = (const char *) ids->data; |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR( |
| stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); |
| SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); |
|
|
// Shallow tensor copies used as views; their shapes, strides and data pointers are patched below.
| ggml_tensor src0_row = *src0; |
| ggml_tensor src1_row = *src1; |
| ggml_tensor dst_row = *dst; |
|
|
| char *src0_original = (char *)src0->data; |
| char *src1_original = (char *)src1->data; |
| char *dst_original = (char *)dst->data; |
|
|
| src0_row.ne[2] = 1; |
| src0_row.ne[3] = 1; |
| src0_row.nb[3] = nb02; |
|
|
| src1_row.ne[1] = 1; |
| src1_row.ne[2] = 1; |
| src1_row.ne[3] = 1; |
| src1_row.nb[2] = nb11; |
| src1_row.nb[3] = nb11; |
|
|
| dst_row.ne[1] = 1; |
| dst_row.ne[2] = 1; |
| dst_row.ne[3] = 1; |
| dst_row.nb[2] = nb1; |
| dst_row.nb[3] = nb1; |
// Strategy 1: src1 has a single slice in dim 2 - multiply row by row.
| if (ne12 == 1) { |
| for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { |
| for (int64_t id = 0; id < n_ids; id++) { |
| const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); |
| GGML_ASSERT(i02 >= 0 && i02 < n_as); |
|
|
| const int64_t i11 = id % ne11; |
| const int64_t i12 = iid1; |
|
|
| const int64_t i1 = id; |
| const int64_t i2 = i12; |
|
|
| src0_row.data = src0_original + i02*nb02; |
| src1_row.data = src1_original + i11*nb11 + i12*nb12; |
| dst_row.data = dst_original + i1*nb1 + i2*nb2; |
|
|
| ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); |
| } |
| } |
| } else { |
// Strategy 2: gather / multiply / scatter per src0 slice, using staging buffers sized for all of src1 and dst.
| ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); |
| ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); |
|
|
| src1_row.data = src1_contiguous.get(); |
| dst_row.data = dst_contiguous.get(); |
|
|
| for (int64_t i02 = 0; i02 < n_as; i02++) { |
// Count on the host how many id entries select slice i02.
| int64_t num_src1_rows = 0; |
| for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { |
| for (int64_t id = 0; id < n_ids; id++) { |
| const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); |
|
|
| GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); |
|
|
| if (row_id_i != i02) { |
| continue; |
| } |
|
|
| num_src1_rows++; |
| } |
| } |
|
|
| if (num_src1_rows == 0) { |
| continue; |
| } |
|
|
|
|
// Device-side row counter (zeroed here) and the mapping from staged rows back to dst positions.
| ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1); |
| ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows); |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); |
|
|
| const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; |
| assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); |
|
|
// Gather: one work-group per ids entry copies the matching src1 row into src1_contiguous.
| { |
| sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); |
| sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); |
| stream->submit([&](sycl::handler &cgh) { |
| sycl::local_accessor<int, 0> src1_row_acc(cgh); |
|
|
| char *__restrict src1_contiguous_get = |
| src1_contiguous.get(); |
| int *__restrict dev_cur_src1_row_get = |
| dev_cur_src1_row.get(); |
| mmid_row_mapping *__restrict dev_row_mapping_get = |
| dev_row_mapping.get(); |
| size_t ids_nb_ct6 = ids->nb[1]; |
| size_t ids_nb_ct7 = ids->nb[0]; |
|
|
| cgh.parallel_for( |
| sycl::nd_range<3>(grid_dims * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) { |
| k_copy_src1_to_contiguous( |
| src1_original, src1_contiguous_get, |
| dev_cur_src1_row_get, |
| dev_row_mapping_get, ids_dev, i02, |
| ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, |
| item_ct1, src1_row_acc); |
| }); |
| }); |
| } |
|
|
| src0_row.data = src0_original + i02*nb02; |
|
|
// The staging layout assumes densely packed F32 rows in both src1 and dst.
| GGML_ASSERT(nb11 == sizeof(float)*ne10); |
| GGML_ASSERT(nb1 == sizeof(float)*ne0); |
| src1_row.ne[1] = num_src1_rows; |
|
|
| src1_row.nb[1] = nb11; |
| src1_row.nb[2] = num_src1_rows*nb11; |
| src1_row.nb[3] = num_src1_rows*nb11; |
|
|
| dst_row.ne[1] = num_src1_rows; |
| dst_row.nb[1] = nb1; |
| dst_row.nb[2] = num_src1_rows*nb1; |
| dst_row.nb[3] = num_src1_rows*nb1; |
|
|
| ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); |
|
|
// Scatter: one work-group per staged result row copies it to its recorded position in dst.
| { |
| sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); |
| sycl::range<3> grid_dims(1, 1, num_src1_rows); |
| stream->submit([&](sycl::handler &cgh) { |
| const char *__restrict dst_contiguous_get = |
| dst_contiguous.get(); |
| const mmid_row_mapping *__restrict dev_row_mapping_get = |
| dev_row_mapping.get(); |
|
|
| cgh.parallel_for( |
| sycl::nd_range<3>(grid_dims * block_dims, block_dims), |
| [=](sycl::nd_item<3> item_ct1) { |
| k_copy_dst_from_contiguous(dst_original, |
| dst_contiguous_get, |
| dev_row_mapping_get, |
| ne0, nb1, nb2, item_ct1); |
| }); |
| }); |
| } |
| } |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_scale(ctx, dst); |
| } |
|
|
| static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_diag_mask_inf(ctx, dst); |
| } |
|
|
| static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| ggml_sycl_op_pool2d(ctx, dst); |
| } |
|
|
| static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 2); |
| ggml_sycl_op_im2col(ctx, dst); |
| } |
|
|
| static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| GGML_ASSERT(ggml_is_contiguous(dst->src[0])); |
| ggml_sycl_op_sum(ctx, dst); |
| } |
|
|
| static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| GGML_ASSERT(ggml_is_contiguous(dst->src[0])); |
| ggml_sycl_op_sum_rows(ctx, dst); |
| } |
|
|
| static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| GGML_ASSERT(ggml_is_contiguous(dst->src[0])); |
| ggml_sycl_op_mean(ctx, dst); |
| } |
|
|
| static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| GGML_ASSERT(ggml_is_contiguous(dst->src[0])); |
| ggml_sycl_op_argsort(ctx, dst); |
| } |
|
|
| static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| scope_op_debug_print scope_dbg_print(__func__, dst, 1); |
| GGML_ASSERT(ggml_is_contiguous(dst->src[0])); |
| ggml_sycl_op_argmax(ctx, dst); |
| } |
|
|
|
|
| static void ggml_sycl_set_main_device(const int main_device) try { |
| if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) { |
| return; |
| } |
| check_allow_gpu_index(main_device); |
| dpct::select_device(main_device); |
|
|
| if (g_ggml_sycl_debug) { |
| dpct::device_info prop; |
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( |
| prop, dpct::dev_mgr::instance().get_device(main_device)))); |
| GGML_LOG_INFO("Using device %d (%s) as main device\n", |
| main_device, prop.get_name()); |
| } |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try { |
| if (!g_sycl_loaded) return false; |
|
|
| if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) { |
| ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device); |
| } |
|
|
| switch (dst->op) { |
| case GGML_OP_ARGMAX: |
| ggml_sycl_argmax(ctx, dst); |
| break; |
| case GGML_OP_CONV_TRANSPOSE_1D: |
| ggml_sycl_op_conv_transpose_1d(ctx, dst); |
| break; |
| case GGML_OP_REPEAT: |
| ggml_sycl_repeat(ctx, dst); |
| break; |
| case GGML_OP_REPEAT_BACK: |
| ggml_sycl_repeat_back(ctx, dst); |
| break; |
| case GGML_OP_GET_ROWS: |
| ggml_sycl_get_rows(ctx, dst); |
| break; |
| case GGML_OP_SET: |
| ggml_sycl_op_set(ctx, dst); |
| break; |
| case GGML_OP_SET_ROWS: |
| ggml_sycl_op_set_rows(ctx, dst); |
| break; |
| case GGML_OP_DUP: |
| ggml_sycl_dup(ctx, dst); |
| break; |
| case GGML_OP_ADD: |
| case GGML_OP_ADD1: |
| ggml_sycl_add(ctx, dst); |
| break; |
| case GGML_OP_ADD_ID: |
| ggml_sycl_add_id(ctx, dst); |
| break; |
| case GGML_OP_SUB: |
| ggml_sycl_sub(ctx, dst); |
| break; |
| case GGML_OP_COUNT_EQUAL: |
| ggml_sycl_count_equal(ctx, dst); |
| break; |
| case GGML_OP_ACC: |
| ggml_sycl_acc(ctx, dst); |
| break; |
| case GGML_OP_MUL: |
| ggml_sycl_mul(ctx, dst); |
| break; |
| case GGML_OP_LOG: |
| ggml_sycl_log(ctx, dst); |
| break; |
| case GGML_OP_DIV: |
| ggml_sycl_div(ctx, dst); |
| break; |
| case GGML_OP_UNARY: |
| switch (ggml_get_unary_op(dst)) { |
| case GGML_UNARY_OP_NEG: |
| ggml_sycl_neg(ctx, dst); |
| break; |
| case GGML_UNARY_OP_STEP: |
| ggml_sycl_step(ctx, dst); |
| break; |
| case GGML_UNARY_OP_GELU: |
| ggml_sycl_gelu(ctx, dst); |
| break; |
| case GGML_UNARY_OP_SILU: |
| ggml_sycl_silu(ctx, dst); |
| break; |
| case GGML_UNARY_OP_GELU_QUICK: |
| ggml_sycl_gelu_quick(ctx, dst); |
| break; |
| case GGML_UNARY_OP_GELU_ERF: |
| ggml_sycl_gelu_erf(ctx, dst); |
| break; |
| case GGML_UNARY_OP_TANH: |
| ggml_sycl_tanh(ctx, dst); |
| break; |
| case GGML_UNARY_OP_RELU: |
| ggml_sycl_relu(ctx, dst); |
| break; |
| case GGML_UNARY_OP_SIGMOID: |
| ggml_sycl_sigmoid(ctx, dst); |
| break; |
| case GGML_UNARY_OP_HARDSIGMOID: |
| ggml_sycl_hardsigmoid(ctx, dst); |
| break; |
| case GGML_UNARY_OP_HARDSWISH: |
| ggml_sycl_hardswish(ctx, dst); |
| break; |
| case GGML_UNARY_OP_EXP: |
| ggml_sycl_exp(ctx, dst); |
| break; |
| case GGML_UNARY_OP_SOFTPLUS: |
| ggml_sycl_softplus(ctx, dst); |
| break; |
| case GGML_UNARY_OP_SGN: |
| ggml_sycl_sgn(ctx, dst); |
| break; |
| case GGML_UNARY_OP_ABS: |
| ggml_sycl_abs(ctx, dst); |
| break; |
| case GGML_UNARY_OP_ELU: |
| ggml_sycl_elu(ctx, dst); |
| break; |
| case GGML_UNARY_OP_FLOOR: |
| ggml_sycl_floor(ctx, dst); |
| break; |
| case GGML_UNARY_OP_CEIL: |
| ggml_sycl_ceil(ctx, dst); |
| break; |
| case GGML_UNARY_OP_ROUND: |
| ggml_sycl_round(ctx, dst); |
| break; |
| case GGML_UNARY_OP_TRUNC: |
| ggml_sycl_trunc(ctx, dst); |
| break; |
| default: |
| return false; |
| } |
| break; |
| case GGML_OP_GLU: |
| switch (ggml_get_glu_op(dst)) { |
| case GGML_GLU_OP_REGLU: |
| ggml_sycl_reglu(ctx, dst); |
| break; |
| case GGML_GLU_OP_GEGLU: |
| ggml_sycl_geglu(ctx, dst); |
| break; |
| case GGML_GLU_OP_SWIGLU: |
| ggml_sycl_swiglu(ctx, dst); |
| break; |
| case GGML_GLU_OP_SWIGLU_OAI: |
| ggml_sycl_swiglu_oai(ctx, dst); |
| break; |
| case GGML_GLU_OP_GEGLU_ERF: |
| ggml_sycl_geglu_erf(ctx, dst); |
| break; |
| case GGML_GLU_OP_GEGLU_QUICK: |
| ggml_sycl_geglu_quick(ctx, dst); |
| break; |
| default: |
| return false; |
| } |
| break; |
| case GGML_OP_NORM: |
| ggml_sycl_norm(ctx, dst); |
| break; |
| case GGML_OP_GROUP_NORM: |
| ggml_sycl_group_norm(ctx, dst); |
| break; |
| case GGML_OP_CONCAT: |
| ggml_sycl_op_concat(ctx, dst); |
| break; |
| case GGML_OP_PAD_REFLECT_1D: |
| ggml_sycl_op_pad_reflect_1d(ctx,dst); |
| break; |
| case GGML_OP_UPSCALE: |
| ggml_sycl_upscale(ctx, dst); |
| break; |
| case GGML_OP_PAD: |
| ggml_sycl_pad(ctx, dst); |
| break; |
| case GGML_OP_LEAKY_RELU: |
| ggml_sycl_leaky_relu(ctx, dst); |
| break; |
| case GGML_OP_RMS_NORM_BACK: |
| ggml_sycl_rms_norm_back(ctx, dst); |
| break; |
| case GGML_OP_RMS_NORM: |
| ggml_sycl_rms_norm(ctx, dst); |
| break; |
| case GGML_OP_L2_NORM: |
| ggml_sycl_l2_norm(ctx, dst); |
| break; |
| case GGML_OP_MUL_MAT: |
| if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { |
| return false; |
| } |
| |
| ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst); |
| break; |
| case GGML_OP_MUL_MAT_ID: |
| if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { |
| return false; |
| } |
| ggml_sycl_mul_mat_id(ctx, dst); |
| break; |
| case GGML_OP_OUT_PROD: |
| ggml_sycl_op_out_prod(ctx, dst); |
| break; |
| case GGML_OP_SCALE: |
| ggml_sycl_scale(ctx, dst); |
| break; |
| case GGML_OP_SQR: |
| ggml_sycl_sqr(ctx, dst); |
| break; |
| case GGML_OP_SQRT: |
| ggml_sycl_sqrt(ctx, dst); |
| break; |
| case GGML_OP_SIN: |
| ggml_sycl_sin(ctx, dst); |
| break; |
| case GGML_OP_COS: |
| ggml_sycl_cos(ctx, dst); |
| break; |
| case GGML_OP_CLAMP: |
| ggml_sycl_clamp(ctx, dst); |
| break; |
| case GGML_OP_CPY: |
| ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]); |
| break; |
| case GGML_OP_CONT: |
| ggml_sycl_dup(ctx, dst); |
| break; |
| case GGML_OP_NONE: |
| case GGML_OP_RESHAPE: |
| case GGML_OP_VIEW: |
| case GGML_OP_PERMUTE: |
| case GGML_OP_TRANSPOSE: |
| GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__); |
| break; |
| case GGML_OP_TRI: |
| ggml_sycl_op_tri(ctx, dst); |
| break; |
| case GGML_OP_DIAG_MASK_INF: |
| ggml_sycl_diag_mask_inf(ctx, dst); |
| break; |
| case GGML_OP_SOFT_MAX: |
| ggml_sycl_op_soft_max(ctx, dst); |
| break; |
| case GGML_OP_SOFT_MAX_BACK: |
| ggml_sycl_op_soft_max_back(ctx, dst); |
| break; |
| case GGML_OP_ROPE: |
| ggml_sycl_rope(ctx, dst); |
| break; |
| case GGML_OP_ROPE_BACK: |
| ggml_sycl_rope_back(ctx, dst); |
| break; |
| case GGML_OP_IM2COL: |
| ggml_sycl_im2col(ctx, dst); |
| break; |
| case GGML_OP_POOL_2D: |
| ggml_sycl_pool2d(ctx, dst); |
| break; |
| case GGML_OP_SUM: |
| ggml_sycl_sum(ctx, dst); |
| break; |
| case GGML_OP_SUM_ROWS: |
| ggml_sycl_sum_rows(ctx, dst); |
| break; |
| case GGML_OP_MEAN: |
| ggml_sycl_mean(ctx, dst); |
| break; |
| case GGML_OP_ARGSORT: |
| ggml_sycl_argsort(ctx, dst); |
| break; |
| case GGML_OP_TOP_K: |
| ggml_sycl_op_top_k(ctx, dst); |
| break; |
| case GGML_OP_TIMESTEP_EMBEDDING: |
| ggml_sycl_op_timestep_embedding(ctx, dst); |
| break; |
| case GGML_OP_RWKV_WKV6: |
| ggml_sycl_op_rwkv_wkv6(ctx, dst); |
| break; |
| case GGML_OP_RWKV_WKV7: |
| ggml_sycl_op_rwkv_wkv7(ctx, dst); |
| break; |
| case GGML_OP_GATED_LINEAR_ATTN: |
| ggml_sycl_op_gated_linear_attn(ctx, dst); |
| break; |
| case GGML_OP_GATED_DELTA_NET: |
| ggml_sycl_gated_delta_net(ctx, dst); |
| break; |
| case GGML_OP_SSM_CONV: |
| ggml_sycl_ssm_conv(ctx, dst); |
| break; |
| case GGML_OP_ROLL: |
| ggml_sycl_roll(ctx, dst); |
| break; |
| case GGML_OP_ARANGE: |
| ggml_sycl_arange(ctx, dst); |
| break; |
| case GGML_OP_FLASH_ATTN_EXT: |
| ggml_sycl_flash_attn_ext(ctx, dst); |
| break; |
| default: |
| return false; |
| } |
|
|
| return true; |
| } catch (sycl::exception & e) { |
| std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; |
| std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl; |
| std::exit(1); |
| } |
|
|
| GGML_API void ggml_backend_sycl_get_device_description(int device, char *description, |
| size_t description_size) try { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n"); |
| dpct::device_info prop; |
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( |
| prop, dpct::dev_mgr::instance().get_device(device)))); |
| snprintf(description, description_size, "%s", prop.get_name()); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| void ggml_backend_sycl_get_device_memory(int device, size_t *free, |
| size_t *total) try { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n"); |
| ggml_sycl_set_device(device); |
|
|
| SYCL_CHECK(CHECK_TRY_ERROR( |
| dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total))); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| |
|
|
| |
|
|
| static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) { |
|
|
| ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; |
|
|
| return sycl_ctx->name.c_str(); |
| } |
|
|
| static void ggml_backend_sycl_free(ggml_backend_t backend) { |
| ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; |
|
|
| delete sycl_ctx; |
| delete backend; |
| } |
|
|
| static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, |
| ggml_tensor *tensor, |
| const void *data, size_t offset, |
| size_t size) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); |
| ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; |
| ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; |
|
|
| GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); |
| const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| (stream)->memcpy((char *)tensor->data + offset, data, size))); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, |
| const ggml_tensor *tensor, |
| void *data, size_t offset, |
| size_t size) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); |
| GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); |
| ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; |
| ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; |
|
|
| GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); |
| const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); |
| SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( |
| data, (const char *)tensor->data + offset, size))); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, |
| const ggml_tensor *src, |
| ggml_tensor *dst) try { |
| ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; |
| bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && |
| ggml_backend_buffer_is_sycl(src->buffer); |
| GGML_SYCL_DEBUG("[SYCL] call %s", __func__); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str()); |
| GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str()); |
| GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); |
| if (is_cpy_supported) { |
| |
| |
| |
| |
| |
| const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); |
| SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( |
| dst->data, src->data, ggml_nbytes(dst)))); |
| return true; |
| } |
|
|
| return false; |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); |
| ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; |
| const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); |
| SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait())); |
|
|
| GGML_UNUSED(backend); |
| } |
| catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) { |
| ggml_sycl_set_main_device(sycl_ctx->device); |
|
|
| for (int i = 0; i < cgraph->n_nodes; i++) { |
| ggml_tensor * node = cgraph->nodes[i]; |
| if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { |
| continue; |
| } |
| if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { |
| continue; |
| } |
| #ifndef NDEBUG |
| assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); |
| for (int j = 0; j < GGML_MAX_SRC; j++) { |
| if (node->src[j] != nullptr) { |
| assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); |
| } |
| } |
| #endif |
| bool ok = ggml_sycl_compute_forward(*sycl_ctx, node); |
| if (!ok) { |
| GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); |
| } |
| GGML_ASSERT(ok); |
| } |
| } |
|
|
#ifdef GGML_SYCL_GRAPH
// Decides whether `cgraph` may be executed through a SYCL command graph.
//
// Returns false (and logs the reason at INFO level) when
//   - more than one SYCL device is present,
//   - the graph contains a GGML_OP_CONCAT or GGML_OP_MUL_MAT_ID node, or
//   - the graph contains a GGML_OP_MUL_MAT node while
//     g_ggml_sycl_use_async_mem_op is 0.
// Returns true otherwise.
static bool check_graph_compatibility(ggml_cgraph * cgraph) {
    if (ggml_sycl_info().device_count > 1) {
        GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
        return false;
    }

    for (int i = 0; i < cgraph->n_nodes; i++) {
        const ggml_op node_op = cgraph->nodes[i]->op;

        // Node types that are never recorded into a graph.
        if (node_op == GGML_OP_CONCAT || node_op == GGML_OP_MUL_MAT_ID) {
            GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
                          ggml_op_name(node_op));
            return false;
        }

        // MUL_MAT is only accepted when async memory operations are in use.
        if (node_op == GGML_OP_MUL_MAT && !g_ggml_sycl_use_async_mem_op) {
            GGML_LOG_INFO(
                "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
                "oneAPI async memory allocation extension "
                "%s\n",
                __func__, ggml_op_name(node_op));
            return false;
        }
    }

    return true;
}
#endif
|
|
| static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { |
| auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context); |
|
|
| #ifdef GGML_SYCL_GRAPH |
| bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph); |
| if (use_sycl_graph) { |
| const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); |
| if (!graph_support) { |
| GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device); |
| ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); |
| return GGML_STATUS_SUCCESS; |
| } |
|
|
| sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}}); |
|
|
| model_sycl_graph.begin_recording(*(sycl_ctx->stream())); |
| ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); |
| model_sycl_graph.end_recording(); |
|
|
| const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph); |
| if (!sycl_ctx->exec_graph || !graph_update_support) { |
| auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) : |
| model_sycl_graph.finalize(); |
| sycl_ctx->exec_graph = std::make_unique< |
| sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph); |
| } else { |
| try { |
| sycl_ctx->exec_graph->update(model_sycl_graph); |
| GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n"); |
| } catch (sycl::exception const & e) { |
| GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what()); |
| auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}}); |
| sycl_ctx->exec_graph = std::make_unique< |
| sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph); |
| } |
| } |
|
|
| sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph)); |
| } else |
| #endif |
| { |
| ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); |
| } |
| return GGML_STATUS_SUCCESS; |
| } |
|
|
| static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event) |
| try |
| { |
| ggml_backend_sycl_context *sycl_ctx = |
| (ggml_backend_sycl_context *)backend->context; |
|
|
| sycl::event *sycl_event = static_cast<sycl::event *>(event->context); |
|
|
| const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0); |
| |
| SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier())); |
| } |
| catch (sycl::exception const &exc) |
| { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try { |
| GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); |
| sycl::event* sycl_event = static_cast<sycl::event*>(event->context); |
|
|
| if (ggml_backend_is_sycl(backend)) { |
| SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); |
| } else |
| GGML_ABORT("fatal error"); |
| } catch (sycl::exception const& exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static ggml_backend_i ggml_backend_sycl_interface = { |
| ggml_backend_sycl_get_name, |
| ggml_backend_sycl_free, |
| ggml_backend_sycl_set_tensor_async, |
| ggml_backend_sycl_get_tensor_async, |
| NULL, |
| |
| |
| ggml_backend_sycl_synchronize, |
| NULL, |
| NULL, |
| NULL, |
| NULL, |
| ggml_backend_sycl_graph_compute, |
| ggml_backend_sycl_event_record, |
| ggml_backend_sycl_event_wait, |
| NULL, |
| }; |
|
|
| static ggml_guid_t ggml_backend_sycl_guid() { |
| static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 }; |
| return &guid; |
| } |
|
|
| bool ggml_backend_is_sycl(ggml_backend_t backend) { |
| return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid()); |
| } |
|
|
| int ggml_backend_sycl_get_device_count() { |
| return ggml_sycl_info().device_count; |
| } |
|
|
|
|
| |
|
|
// Per-device state attached to a ggml backend device (dev->context).
//
// The integer members carry default member initializers so that a
// default-constructed context never holds indeterminate values; aggregate
// initialization in declaration order keeps working.
struct ggml_backend_sycl_device_context {
    int         device = 0;                     // SYCL device index used for lookups
    std::string name;                           // returned by the device get_name callback
    std::string description;                    // returned by the device get_description callback
    int         op_offload_min_batch_size = 0;  // threshold compared against an op's batch size in offload_op
};
|
|
| static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) { |
| ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; |
| return ctx->name.c_str(); |
| } |
|
|
| static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) { |
| ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; |
| return ctx->description.c_str(); |
| } |
|
|
| static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { |
| ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; |
| ggml_sycl_set_device(ctx->device); |
| SYCL_CHECK(CHECK_TRY_ERROR( |
| dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total))); |
| } |
|
|
| static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) { |
| GGML_UNUSED(dev); |
| return GGML_BACKEND_DEVICE_TYPE_GPU; |
| } |
|
|
| static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { |
| props->name = ggml_backend_sycl_device_get_name(dev); |
| props->description = ggml_backend_sycl_device_get_description(dev); |
| props->type = ggml_backend_sycl_device_get_type(dev); |
| ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total); |
|
|
| bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr; |
| #ifdef GGML_SYCL_NO_PEER_COPY |
| bool events = false; |
| #else |
| bool events = true; |
| #endif |
|
|
| props->caps = { |
| true, |
| host_buffer, |
| false, |
| events, |
| }; |
| } |
|
|
| static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) { |
| GGML_UNUSED(params); |
| ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; |
| return ggml_backend_sycl_init(ctx->device); |
| } |
|
|
| static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) { |
| ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; |
| return ggml_backend_sycl_buffer_type(ctx->device); |
| } |
|
|
| static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) { |
| GGML_UNUSED(dev); |
| return ggml_backend_sycl_host_buffer_type(); |
| } |
|
|
| static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { |
| GGML_UNUSED(dev); |
| GGML_UNUSED(ptr); |
| GGML_UNUSED(size); |
| GGML_UNUSED(max_tensor_size); |
| return nullptr; |
| } |
|
|
| static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { |
| ggml_backend_sycl_device_context *sycl_ctx = |
| (ggml_backend_sycl_device_context *)dev->context; |
| int device = sycl_ctx->device; |
| switch (op->op) { |
| case GGML_OP_CONV_TRANSPOSE_1D: |
| { |
| ggml_type src0_type = op->src[0]->type; |
| ggml_type src1_type = op->src[1]->type; |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| return false; |
| } |
| case GGML_OP_UNARY: |
| switch (ggml_get_unary_op(op)) { |
| case GGML_UNARY_OP_SGN: |
| case GGML_UNARY_OP_ABS: |
| case GGML_UNARY_OP_NEG: |
| case GGML_UNARY_OP_STEP: |
| case GGML_UNARY_OP_RELU: |
| case GGML_UNARY_OP_HARDSIGMOID: |
| case GGML_UNARY_OP_TANH: |
| case GGML_UNARY_OP_GELU: |
| case GGML_UNARY_OP_SILU: |
| case GGML_UNARY_OP_SIGMOID: |
| case GGML_UNARY_OP_HARDSWISH: |
| case GGML_UNARY_OP_GELU_QUICK: |
| case GGML_UNARY_OP_GELU_ERF: |
| case GGML_UNARY_OP_EXP: |
| case GGML_UNARY_OP_SOFTPLUS: |
| case GGML_UNARY_OP_ELU: |
| case GGML_UNARY_OP_CEIL: |
| return true; |
| case GGML_UNARY_OP_FLOOR: |
| case GGML_UNARY_OP_ROUND: |
| case GGML_UNARY_OP_TRUNC: |
| #if defined (GGML_SYCL_F16) |
| return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); |
| #else |
| return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); |
| #endif |
| default: |
| return false; |
| } |
| case GGML_OP_GLU: |
| switch (ggml_get_glu_op(op)) { |
| case GGML_GLU_OP_REGLU: |
| case GGML_GLU_OP_GEGLU: |
| case GGML_GLU_OP_SWIGLU: |
| case GGML_GLU_OP_SWIGLU_OAI: |
| case GGML_GLU_OP_GEGLU_ERF: |
| case GGML_GLU_OP_GEGLU_QUICK: |
| return ggml_is_contiguous_1(op->src[0]); |
| default: |
| return false; |
| } |
| break; |
| case GGML_OP_MUL_MAT: |
| case GGML_OP_MUL_MAT_ID: |
| { |
| struct ggml_tensor * a = op->src[0]; |
| struct ggml_tensor * b = op->src[1]; |
|
|
| if (a->ne[3] != b->ne[3]) { |
| return false; |
| } |
| ggml_type a_type = a->type; |
| if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || |
| a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || |
| a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || |
| a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M |
| ) { |
| if (b->ne[1] == 1 && ggml_nrows(b) > 1) { |
| return false; |
| } |
| } |
| ggml_type src0_type = op->src[0]->type; |
| if (src0_type == GGML_TYPE_BF16 ) { |
| |
| |
| return false; |
| } |
|
|
| |
| if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && |
| a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) { |
| return false; |
| } |
|
|
| |
| if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 && |
| a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) { |
| return false; |
| } |
| return true; |
| } |
| case GGML_OP_OUT_PROD: |
| return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; |
| case GGML_OP_GET_ROWS: |
| { |
| switch (op->src[0]->type) { |
| case GGML_TYPE_F16: |
| case GGML_TYPE_F32: |
| case GGML_TYPE_Q4_0: |
| case GGML_TYPE_Q4_1: |
| case GGML_TYPE_Q5_0: |
| case GGML_TYPE_Q5_1: |
| case GGML_TYPE_Q8_0: |
| return true; |
| default: |
| return false; |
| } |
| } |
| case GGML_OP_SET: |
| return (op->type == GGML_TYPE_F32) && |
| (op->src[0] && op->src[1]) && |
| (op->src[0]->type == GGML_TYPE_F32) && |
| (op->src[1]->type == GGML_TYPE_F32); |
|
|
| case GGML_OP_SET_ROWS: |
| { |
| return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || |
| op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 || |
| op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) && |
| (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32)); |
| } |
| break; |
| case GGML_OP_CPY: |
| { |
| ggml_type src0_type = op->src[0]->type; |
| ggml_type src1_type = op->src[1]->type; |
| if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) { |
| return true; |
| } |
| if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { |
| return true; |
| } |
| if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) { |
| return true; |
| } |
| if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) { |
| return true; |
| } |
| if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) { |
| return true; |
| } |
| if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) { |
| return true; |
| } |
| if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) { |
| return true; |
| } |
| return false; |
| } |
| case GGML_OP_REPEAT_BACK: |
| { |
| ggml_type src0_type = op->src[0]->type; |
| return src0_type == GGML_TYPE_F32; |
| } |
| case GGML_OP_CONCAT: |
| case GGML_OP_DUP: |
| case GGML_OP_ARGMAX: |
| case GGML_OP_NONE: |
| case GGML_OP_RESHAPE: |
| case GGML_OP_VIEW: |
| case GGML_OP_PERMUTE: |
| case GGML_OP_TRANSPOSE: |
| case GGML_OP_ADD: |
| case GGML_OP_ADD1: |
| case GGML_OP_ADD_ID: |
| case GGML_OP_SUB: |
| case GGML_OP_COUNT_EQUAL: |
| case GGML_OP_MUL: |
| case GGML_OP_DIV: |
| case GGML_OP_REPEAT: |
| return true; |
| case GGML_OP_PAD_REFLECT_1D: |
| return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; |
| case GGML_OP_SQR: |
| case GGML_OP_SQRT: |
| case GGML_OP_SIN: |
| case GGML_OP_COS: |
| case GGML_OP_CLAMP: |
| case GGML_OP_LOG: |
| #if defined (GGML_SYCL_F16) |
| return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type)); |
| #else |
| return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); |
| #endif |
| case GGML_OP_NORM: |
| case GGML_OP_L2_NORM: |
| case GGML_OP_GROUP_NORM: |
| case GGML_OP_RMS_NORM: |
| return true; |
| case GGML_OP_RMS_NORM_BACK: |
| return ggml_is_contiguous(op->src[0]); |
| case GGML_OP_SCALE: |
| return true; |
| case GGML_OP_CONT: |
| return op->src[0]->type != GGML_TYPE_BF16; |
| case GGML_OP_TRI: |
| { |
| const ggml_tensor * src0 = op->src[0]; |
| return src0 && |
| op->type == GGML_TYPE_F32 && |
| ggml_is_contiguous(src0); |
| } |
| case GGML_OP_DIAG_MASK_INF: |
| return true; |
| case GGML_OP_SOFT_MAX: |
| return true; |
| case GGML_OP_SOFT_MAX_BACK: { |
| float max_bias = 0.0f; |
| memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); |
| return max_bias == 0.0f; |
| } |
| case GGML_OP_ROPE: |
| case GGML_OP_ROPE_BACK: |
| case GGML_OP_IM2COL: |
| case GGML_OP_UPSCALE: |
| return true; |
| case GGML_OP_SUM: |
| case GGML_OP_SUM_ROWS: |
| case GGML_OP_MEAN: |
| return ggml_is_contiguous(op->src[0]); |
| case GGML_OP_ARGSORT: |
| return op->src[0]->ne[0] * sizeof(int) <= |
| ggml_sycl_info().devices[device].smpbo; |
| case GGML_OP_TOP_K: { |
| const ggml_tensor * src0 = op->src[0]; |
| const int k = op->ne[0]; |
| return src0 && |
| op->type == GGML_TYPE_I32 && |
| src0->type == GGML_TYPE_F32 && |
| ggml_is_contiguous(src0) && |
| k > 0 && k <= 32; |
| } |
| case GGML_OP_POOL_2D: |
| return true; |
| case GGML_OP_ACC: |
| return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); |
| case GGML_OP_PAD: |
| |
| if (ggml_get_op_params_i32(op, 8) != 0) { |
| return false; |
| } |
| return ggml_is_contiguous(op->src[0]); |
| case GGML_OP_LEAKY_RELU: |
| case GGML_OP_TIMESTEP_EMBEDDING: |
| case GGML_OP_RWKV_WKV6: |
| case GGML_OP_RWKV_WKV7: |
| case GGML_OP_GATED_LINEAR_ATTN: |
| case GGML_OP_GATED_DELTA_NET: |
| return true; |
| case GGML_OP_SSM_CONV: |
| return op->type == GGML_TYPE_F32 && |
| op->src[0]->type == GGML_TYPE_F32 && |
| op->src[1]->type == GGML_TYPE_F32; |
| case GGML_OP_ROLL: |
| return op->type == GGML_TYPE_F32; |
| case GGML_OP_ARANGE: |
| return op->type == GGML_TYPE_F32; |
| case GGML_OP_FLASH_ATTN_EXT: |
| return ggml_sycl_flash_attn_ext_supported(device, op); |
| default: |
| return false; |
| } |
|
|
| GGML_UNUSED(dev); |
| } |
|
|
| static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { |
| if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) { |
| return false; |
| } |
| ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; |
| ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context; |
| return buft_ctx->device == sycl_ctx->device; |
| } |
|
|
| static int64_t get_op_batch_size(const ggml_tensor * op) { |
| switch (op->op) { |
| case GGML_OP_GET_ROWS: |
| return 0; |
| case GGML_OP_MUL_MAT: |
| return op->ne[1]; |
| case GGML_OP_MUL_MAT_ID: |
| case GGML_OP_ROPE: |
| return op->ne[2]; |
| default: |
| return ggml_nrows(op); |
| } |
| } |
|
|
| static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { |
| ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context; |
| return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size; |
| } |
|
|
| static ggml_backend_event_t |
| ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) { |
|
|
| #ifdef GGML_SYCL_NO_PEER_COPY |
| return nullptr; |
| #else |
| sycl::event *event_ptr = new sycl::event(); |
|
|
| return new ggml_backend_event{ |
| dev, |
| event_ptr, |
| }; |
| #endif |
| } |
|
|
| static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try { |
| GGML_UNUSED(dev); |
| if (event == nullptr) { |
| return; |
| } |
|
|
| if (event->context != nullptr) { |
| sycl::event *sycl_event = static_cast<sycl::event *>(event->context); |
| delete sycl_event; |
| event->context = nullptr; |
| } |
|
|
| delete event; |
| } catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
|
|
| static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try { |
| GGML_UNUSED(dev); |
| GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); |
|
|
| sycl::event *sycl_event = static_cast<sycl::event *>(event->context); |
| SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); |
| } catch (sycl::exception const &exc) { |
| std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |
| << ", line:" << __LINE__ << std::endl; |
| std::exit(1); |
| } |
|
|
| static const ggml_backend_device_i ggml_backend_sycl_device_interface = { |
| ggml_backend_sycl_device_get_name, |
| ggml_backend_sycl_device_get_description, |
| ggml_backend_sycl_device_get_memory, |
| ggml_backend_sycl_device_get_type, |
| ggml_backend_sycl_device_get_props, |
| ggml_backend_sycl_device_init, |
| ggml_backend_sycl_device_get_buffer_type, |
| ggml_backend_sycl_device_get_host_buffer_type, |
| ggml_backend_sycl_device_buffer_from_host_ptr, |
| ggml_backend_sycl_device_supports_op, |
| ggml_backend_sycl_device_supports_buft, |
| ggml_backend_sycl_device_offload_op, |
| ggml_backend_sycl_device_event_new, |
| ggml_backend_sycl_device_event_free, |
| ggml_backend_sycl_device_event_synchronize, |
| }; |
|
|
| |
|
|
| struct ggml_backend_sycl_reg_context { |
| std::vector<ggml_backend_dev_t> devices; |
| }; |
|
|
| static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) { |
| GGML_UNUSED(reg); |
| return GGML_SYCL_NAME; |
| } |
|
|
| static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) { |
| ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context; |
| return ctx->devices.size(); |
| } |
|
|
| static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) { |
| ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context; |
| GGML_ASSERT(index < ctx->devices.size()); |
| return ctx->devices[index]; |
| } |
|
|
| static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) { |
| GGML_UNUSED(reg); |
|
|
| if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { |
| return (void *)ggml_backend_sycl_split_buffer_type; |
| } |
|
|
| |
| |
| |
| GGML_UNUSED(name); |
| return nullptr; |
| } |
|
|
| static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = { |
| ggml_backend_sycl_reg_get_name, |
| ggml_backend_sycl_reg_get_device_count, |
| ggml_backend_sycl_reg_get_device, |
| ggml_backend_sycl_reg_get_proc_address, |
| }; |
|
|
|
|
| |
|
|
| ggml_backend_reg_t ggml_backend_sycl_reg() { |
| static ggml_backend_reg reg; |
| static bool initialized = false; |
|
|
| { |
| static std::mutex mutex; |
| std::lock_guard<std::mutex> lock(mutex); |
| if (!initialized) { |
| ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context; |
| const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; |
|
|
| for (int i = 0; i < ggml_sycl_info().device_count; i++) { |
| ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context; |
| dev_ctx->device = i; |
| dev_ctx->name = GGML_SYCL_NAME + std::to_string(i); |
|
|
| ggml_sycl_set_device(i); |
|
|
| dpct::device_info prop; |
| SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( |
| prop, dpct::dev_mgr::instance().get_device(i)))); |
|
|
| dev_ctx->description = prop.get_name(); |
| dev_ctx->op_offload_min_batch_size = min_batch_size; |
|
|
| ggml_backend_dev_t dev = new ggml_backend_device { |
| ggml_backend_sycl_device_interface, |
| ®, |
| dev_ctx |
| }; |
| ctx->devices.push_back(dev); |
| } |
|
|
| reg = ggml_backend_reg { |
| GGML_BACKEND_API_VERSION, |
| ggml_backend_sycl_reg_interface, |
| ctx |
| }; |
| } |
|
|
| initialized = true; |
| } |
|
|
| return ® |
| } |
|
|
| ggml_backend_t ggml_backend_sycl_init(int device) { |
| GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n"); |
| ggml_check_sycl(); |
|
|
| check_allow_gpu_index(device); |
|
|
| ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device); |
| if (ctx == nullptr) { |
| GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); |
| return nullptr; |
| }; |
|
|
| ggml_backend_t sycl_backend = new ggml_backend { |
| ggml_backend_sycl_guid(), |
| ggml_backend_sycl_interface, |
| ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), |
| ctx |
| }; |
|
|
| return sycl_backend; |
| } |
|
|
| GGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg) |
|
|