Arrcttacsrks committed on
Commit 0ad6c07 · verified · 1 Parent(s): 467068d

Upload llama.cpp/ggml/src/ggml-kompute.cpp with huggingface_hub

Files changed (1)
  1. llama.cpp/ggml/src/ggml-kompute.cpp +2184 -0
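
As the commit message notes, the file was pushed with the huggingface_hub Python client. A minimal sketch of such an upload is shown below; the repo_id is a placeholder for illustration and is not taken from this commit.

# Minimal sketch of a single-file upload with huggingface_hub.
# The repo_id below is a placeholder; authentication uses a previously
# saved token (e.g. from `huggingface-cli login`).
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="llama.cpp/ggml/src/ggml-kompute.cpp",  # local file to upload
    path_in_repo="llama.cpp/ggml/src/ggml-kompute.cpp",     # destination path inside the repo
    repo_id="<user>/<repo>",                                # placeholder repo ID
    commit_message="Upload llama.cpp/ggml/src/ggml-kompute.cpp with huggingface_hub",
)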
llama.cpp/ggml/src/ggml-kompute.cpp ADDED
@@ -0,0 +1,2184 @@
1
+ #include "ggml-impl.h"
2
+ #include "ggml-backend.h"
3
+ #include "ggml-backend-impl.h"
4
+ #include "ggml-kompute.h"
5
+
6
+ // These are generated at build time by cmake custom command
7
+ #include "shaderop_scale.h"
8
+ #include "shaderop_scale_8.h"
9
+ #include "shaderop_add.h"
10
+ #include "shaderop_addrow.h"
11
+ #include "shaderop_mul.h"
12
+ #include "shaderop_silu.h"
13
+ #include "shaderop_relu.h"
14
+ #include "shaderop_gelu.h"
15
+ #include "shaderop_softmax.h"
16
+ #include "shaderop_norm.h"
17
+ #include "shaderop_rmsnorm.h"
18
+ #include "shaderop_diagmask.h"
19
+ #include "shaderop_mul_mat_f16.h"
20
+ #include "shaderop_mul_mat_q8_0.h"
21
+ #include "shaderop_mul_mat_q4_0.h"
22
+ #include "shaderop_mul_mat_q4_1.h"
23
+ #include "shaderop_mul_mat_q4_k.h"
24
+ #include "shaderop_mul_mat_q6_k.h"
25
+ #include "shaderop_mul_mat_mat_f32.h"
26
+ #include "shaderop_getrows_f32.h"
27
+ #include "shaderop_getrows_f16.h"
28
+ #include "shaderop_getrows_q4_0.h"
29
+ #include "shaderop_getrows_q4_1.h"
30
+ #include "shaderop_getrows_q6_k.h"
31
+ #include "shaderop_rope_f16.h"
32
+ #include "shaderop_rope_f32.h"
33
+ #include "shaderop_cpy_f16_f16.h"
34
+ #include "shaderop_cpy_f16_f32.h"
35
+ #include "shaderop_cpy_f32_f16.h"
36
+ #include "shaderop_cpy_f32_f32.h"
37
+
38
+ #include <algorithm>
39
+ #include <array>
40
+ #include <cassert>
41
+ #include <cstdint>
42
+ #include <cstdio>
43
+ #include <cstring>
44
+ #include <iostream>
45
+ #include <memory>
46
+ #include <mutex>
47
+ #include <stdexcept>
48
+ #include <string>
49
+ #include <unordered_map>
50
+ #include <utility>
51
+ #include <vector>
52
+
53
+ #include <kompute/Kompute.hpp>
54
+ #include <vulkan/vulkan.hpp>
55
+
56
+ #ifdef __linux__
57
+ #include <cstdlib> // for setenv
58
+ #endif
59
+
60
+ #define QK4_0 32
61
+ #define QR4_0 2
62
+ #define QK4_1 32
63
+ #define QK_NL 16
64
+
65
+ typedef ggml_fp16_t half;
66
+
67
+ static std::string ggml_kompute_format_name(int device) {
68
+ return "Kompute" + std::to_string(device);
69
+ }
70
+
71
+ struct ggml_kompute_context {
72
+ int device;
73
+ std::string name;
74
+ std::shared_ptr<vk::DescriptorPool> pool;
75
+
76
+ ggml_kompute_context(int device)
77
+ : device(device), name(ggml_kompute_format_name(device)) {}
78
+ };
79
+
80
+ // FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
81
+ // and consolidate the init functions and simplify object lifetime management. As it currently stands,
82
+ // we *have* to have the kompute manager no matter what for device discovery, but the kompute context
83
+ // is only created when a device is set and vulkan is explicitly turned on.
84
+ static ggml_kompute_context *s_kompute_context = nullptr;
85
+
86
+ class kompute_manager {
87
+ kp::Manager *s_mgr = nullptr;
88
+
89
+ public:
90
+ kp::Manager *operator()() {
91
+ if (s_mgr && !s_mgr->hasInstance()) {
92
+ destroy();
93
+ }
94
+ if (!s_mgr) {
95
+ s_mgr = new kp::Manager;
96
+ }
97
+ return s_mgr;
98
+ }
99
+
100
+ void destroy() {
101
+ delete s_mgr;
102
+ s_mgr = nullptr;
103
+ }
104
+ };
105
+
106
+ static kompute_manager komputeManager;
107
+
108
+ struct ggml_vk_memory {
109
+ void *data = nullptr;
110
+ size_t size = 0;
111
+ vk::DeviceMemory *primaryMemory = nullptr;
112
+ vk::Buffer *primaryBuffer = nullptr;
113
+ vk::DeviceMemory *stagingMemory = nullptr;
114
+ vk::Buffer *stagingBuffer = nullptr;
115
+ };
116
+
117
+ #ifdef __linux__
118
+ __attribute__((constructor))
119
+ static void enable_sam() {
120
+ setenv("RADV_PERFTEST", "sam", false);
121
+ }
122
+ #endif
123
+
124
+ static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
125
+ vk::PhysicalDeviceFeatures availableFeatures;
126
+ physical_device.getFeatures(&availableFeatures);
127
+
128
+ if (!availableFeatures.shaderInt16)
129
+ return false;
130
+
131
+ vk::PhysicalDeviceVulkan11Features availableFeatures11;
132
+ vk::PhysicalDeviceVulkan12Features availableFeatures12;
133
+
134
+ availableFeatures11.pNext = &availableFeatures12;
135
+ availableFeatures12.pNext = nullptr;
136
+
137
+ vk::PhysicalDeviceFeatures2 features2;
138
+ features2.pNext = &availableFeatures11;
139
+
140
+ physical_device.getFeatures2(&features2);
141
+
142
+ if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
143
+ !availableFeatures11.storageBuffer16BitAccess) {
144
+ return false;
145
+ }
146
+
147
+ if (!availableFeatures12.storageBuffer8BitAccess ||
148
+ !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
149
+ !availableFeatures12.shaderFloat16 ||
150
+ !availableFeatures12.shaderInt8) {
151
+ return false;
152
+ }
153
+
154
+ return true;
155
+ }
156
+
157
+ static const char * ggml_vk_getVendorName(uint32_t vendorID) {
158
+ switch (vendorID) {
159
+ case 0x10DE:
160
+ return "nvidia";
161
+ case 0x1002:
162
+ return "amd";
163
+ case 0x8086:
164
+ return "intel";
165
+ default:
166
+ return "unknown";
167
+ }
168
+ }
169
+
170
+ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
171
+ std::vector<ggml_vk_device> results;
172
+ if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
173
+ return results;
174
+
175
+ std::vector<vk::PhysicalDevice> physical_devices;
176
+ try {
177
+ physical_devices = komputeManager()->listDevices();
178
+ } catch (vk::SystemError & err) {
179
+ std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
180
+ return results;
181
+ }
182
+
183
+ uint32_t deviceCount = physical_devices.size();
184
+ if (deviceCount == 0)
185
+ return results;
186
+
187
+ std::unordered_map<std::string, size_t> count_by_name;
188
+
189
+ for (uint32_t i = 0; i < deviceCount; i++) {
190
+ const auto & physical_device = physical_devices[i];
191
+
192
+ VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
193
+ VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
194
+ const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
195
+ const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
196
+ if (major < 1 || minor < 2)
197
+ continue;
198
+
199
+ if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
200
+ continue;
201
+
202
+ size_t heapSize = 0;
203
+ for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
204
+ VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
205
+ if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
206
+ heapSize = heap.size;
207
+ break;
208
+ }
209
+ }
210
+
211
+ if (heapSize < memoryRequired)
212
+ continue;
213
+
214
+ auto ext_props = physical_device.enumerateDeviceExtensionProperties();
215
+ bool has_maintenance4 = false;
216
+
217
+ // Check if maintenance4 is supported
218
+ for (const auto & properties : ext_props) {
219
+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
220
+ has_maintenance4 = true;
221
+ }
222
+ }
223
+
224
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
225
+ vk::PhysicalDeviceProperties2 dev_props2;
226
+ vk::PhysicalDeviceMaintenance3Properties dev_props3;
227
+ vk::PhysicalDeviceMaintenance4Properties dev_props4;
228
+ dev_props2.pNext = &dev_props3;
229
+ dev_props3.pNext = &subgroup_props;
230
+ if (has_maintenance4) {
231
+ subgroup_props.pNext = &dev_props4;
232
+ }
233
+ physical_device.getProperties2(&dev_props2);
234
+
235
+ if (subgroup_props.subgroupSize < 32)
236
+ continue;
237
+
238
+ ggml_vk_device d;
239
+ d.index = i;
240
+ d.type = dev_props.deviceType;
241
+ d.heapSize = heapSize;
242
+ d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
243
+ d.subgroupSize = subgroup_props.subgroupSize;
244
+ d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
245
+
246
+ if (has_maintenance4) {
247
+ d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
248
+ } else {
249
+ d.maxAlloc = dev_props3.maxMemoryAllocationSize;
250
+ }
251
+
252
+ std::string name(dev_props.deviceName);
253
+ size_t n_idx = ++count_by_name[name];
254
+ if (n_idx > 1) {
255
+ name += " (" + std::to_string(n_idx) + ")";
256
+ }
257
+ d.name = strdup(name.c_str());
258
+
259
+ results.push_back(d);
260
+ }
261
+
262
+ std::stable_sort(results.begin(), results.end(),
263
+ [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
264
+ if (lhs.type != rhs.type) {
265
+ if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
266
+ if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
267
+
268
+ if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
269
+ if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
270
+ }
271
+ return lhs.heapSize < rhs.heapSize;
272
+ }
273
+ );
274
+
275
+ return results;
276
+ }
277
+
278
+ static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
279
+ static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
280
+ return devices;
281
+ }
282
+
283
+ static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
284
+ devices.erase(
285
+ std::remove_if(devices.begin(), devices.end(),
286
+ [&targetVendor](const ggml_vk_device& device) {
287
+ return device.vendor != targetVendor;
288
+ }),
289
+ devices.end()
290
+ );
291
+ }
292
+
293
+ static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
294
+ devices.erase(
295
+ std::remove_if(devices.begin(), devices.end(),
296
+ [&targetName](const ggml_vk_device& device) {
297
+ return device.name != targetName;
298
+ }),
299
+ devices.end()
300
+ );
301
+ }
302
+
303
+ static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
304
+ if (name.empty())
305
+ return false;
306
+
307
+ auto devices = ggml_vk_available_devices_internal(memoryRequired);
308
+ if (name == "amd" || name == "nvidia" || name == "intel") {
309
+ ggml_vk_filterByVendor(devices, name);
310
+ } else if (name != "gpu") {
311
+ ggml_vk_filterByName(devices, name);
312
+ }
313
+
314
+ if (devices.empty())
315
+ return false;
316
+
317
+ *device = devices.front();
318
+ return true;
319
+ }
320
+
321
+ bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
322
+ return ggml_vk_get_device(device, memoryRequired, std::string(name));
323
+ }
324
+
325
+ bool ggml_vk_has_vulkan() {
326
+ return komputeManager()->hasVulkan();
327
+ }
328
+
329
+ bool ggml_vk_has_device() {
330
+ return komputeManager()->hasDevice();
331
+ }
332
+
333
+ ggml_vk_device ggml_vk_current_device() {
334
+ if (!komputeManager()->hasDevice())
335
+ return ggml_vk_device();
336
+
337
+ auto devices = ggml_vk_available_devices();
338
+ ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
339
+ GGML_ASSERT(!devices.empty());
340
+ return devices.front();
341
+ }
342
+
343
+ static
344
+ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
345
+ std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
346
+ vk::DescriptorPoolSize(
347
+ vk::DescriptorType::eStorageBuffer,
348
+ 3 * size // Descriptor count is number of possible tensors to pass into an algorithm
349
+ )
350
+ };
351
+
352
+ vk::DescriptorPoolCreateInfo descriptorPoolInfo(
353
+ vk::DescriptorPoolCreateFlags(),
354
+ size, // Max sets
355
+ static_cast<uint32_t>(descriptorPoolSizes.size()),
356
+ descriptorPoolSizes.data());
357
+
358
+ ctx->pool = std::make_shared<vk::DescriptorPool>();
359
+ vk::Result r = komputeManager()->device()->createDescriptorPool(
360
+ &descriptorPoolInfo, nullptr, ctx->pool.get());
361
+ if (r != vk::Result::eSuccess)
362
+ std::cerr << "Error allocating descriptor pool" << vk::to_string(r);
363
+ }
364
+
365
+ static
366
+ void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
367
+ if (ctx->pool) {
368
+ komputeManager()->device()->destroy(
369
+ *ctx->pool,
370
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
371
+ ctx->pool = nullptr;
372
+ }
373
+ }
374
+
375
+ static
376
+ vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
377
+ vk::BufferCreateInfo bufferCreateInfo;
378
+ bufferCreateInfo.size = size;
379
+ bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
380
+ vk::BufferUsageFlagBits::eTransferSrc |
381
+ vk::BufferUsageFlagBits::eTransferDst;
382
+ bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
383
+
384
+ vk::Buffer *vkBuffer = new vk::Buffer;
385
+ vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
386
+ if (r != vk::Result::eSuccess)
387
+ std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
388
+ return vkBuffer;
389
+ }
390
+
391
+ static
392
+ vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
393
+
394
+ uint32_t memoryTypeIndex = -1;
395
+ bool memoryTypeIndexFound = false;
396
+ vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
397
+ for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
398
+ const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
399
+ const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
400
+ if (memoryHeap.size < size) {
401
+ continue;
402
+ }
403
+
404
+ if (requirements.memoryTypeBits & (1 << i)) {
405
+ if (((memoryProperties.memoryTypes[i]).propertyFlags &
406
+ flags) == flags) {
407
+ memoryTypeIndex = i;
408
+ memoryTypeIndexFound = true;
409
+ if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
410
+ *isHostVisible = true;
411
+ }
412
+ break;
413
+ }
414
+ }
415
+ }
416
+ if (!memoryTypeIndexFound) {
417
+ throw std::runtime_error(
418
+ "Memory type index for buffer creation not found");
419
+ }
420
+
421
+ vk::MemoryAllocateInfo allocInfo;
422
+ allocInfo.allocationSize = size;
423
+ allocInfo.memoryTypeIndex = memoryTypeIndex;
424
+ vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory;
425
+ vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
426
+ if (r != vk::Result::eSuccess) {
427
+ std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
428
+ throw std::runtime_error("Error allocating vulkan memory.");
429
+ }
430
+ return vkDeviceMemory;
431
+ }
432
+
433
+ static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
434
+ size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
435
+
436
+ // If offset is already aligned, return it directly
437
+ if (offset % minStorageBufferOffsetAlignment == 0) {
438
+ return offset;
439
+ }
440
+
441
+ // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
442
+ return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
443
+ }
444
+
445
+ static ggml_vk_memory ggml_vk_allocate(size_t size) {
446
+ ggml_vk_memory memory;
447
+ bool isHostVisible = false;
448
+ {
449
+ memory.primaryBuffer = ggml_vk_allocate_buffer(size);
450
+ vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
451
+ vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
452
+ memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
453
+ komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
454
+ if (isHostVisible) {
455
+ vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
456
+ if (r != vk::Result::eSuccess)
457
+ std::cerr << "Error mapping memory" << vk::to_string(r);
458
+ }
459
+ }
460
+
461
+ if (!isHostVisible) {
462
+ memory.stagingBuffer = ggml_vk_allocate_buffer(size);
463
+ vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
464
+ vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
465
+ vk::MemoryPropertyFlagBits::eHostCoherent |
466
+ vk::MemoryPropertyFlagBits::eHostCached;
467
+ memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
468
+ komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
469
+ vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
470
+ if (r != vk::Result::eSuccess)
471
+ std::cerr << "Error mapping memory" << vk::to_string(r);
472
+ }
473
+
474
+ memory.size = size;
475
+ return memory;
476
+ }
477
+
478
+ static void ggml_vk_free_memory(ggml_vk_memory &memory)
479
+ {
480
+ komputeManager()->device()->destroy(
481
+ *memory.primaryBuffer,
482
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
483
+ if (memory.stagingBuffer) {
484
+ komputeManager()->device()->destroy(
485
+ *memory.stagingBuffer,
486
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
487
+ }
488
+ komputeManager()->device()->freeMemory(
489
+ *memory.primaryMemory,
490
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
491
+ if (memory.stagingMemory) {
492
+ komputeManager()->device()->freeMemory(
493
+ *memory.stagingMemory,
494
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
495
+ }
496
+ }
497
+
498
+ static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
499
+
500
+ static
501
+ ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
502
+ ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
503
+
504
+ // compatibility with ggml-backend
505
+ GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
506
+
507
+ ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
508
+
509
+ const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
510
+
511
+ GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
512
+
513
+ offset = uint64_t(ioffs);
514
+ return buf_ctx;
515
+ }
516
+
517
+ static
518
+ const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
519
+ uint64_t originalOffset = 0;
520
+ auto * res = ggml_vk_find_tensor(t, originalOffset);
521
+ if (!res) {
522
+ static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
523
+ return nullTensor;
524
+ }
525
+
526
+ // Create a tensor whose memory will be composed of our buffers at the correct offset
527
+ const size_t nelements = ggml_nelements(t);
528
+ size_t nbytes = ggml_nbytes(t);
529
+
530
+ size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
531
+ if (alignedOffset) {
532
+ *alignedOffset = originalOffset - vulkanOffset;
533
+ nbytes += *alignedOffset;
534
+ }
535
+
536
+ return komputeManager()->tensor(
537
+ t->data,
538
+ nelements,
539
+ nbytes, kp::Tensor::TensorDataTypes::eFloat,
540
+ res->primaryMemory, res->primaryBuffer,
541
+ res->stagingMemory, res->stagingBuffer,
542
+ vulkanOffset);
543
+ }
544
+
545
+ static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
546
+ if (size % sizeof(uint32_t) != 0) {
547
+ throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
548
+ }
549
+
550
+ const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
551
+ size_t count = size / sizeof(uint32_t);
552
+ return std::vector<uint32_t>(data_ptr, data_ptr + count);
553
+ }
554
+
555
+ inline static
556
+ uint32_t safe_divide(uint32_t a, uint32_t b) {
557
+ if (b <= 1) {
558
+ return a;
559
+ }
560
+ if ((a % b) != 0) {
561
+ fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
562
+ GGML_ABORT("safe_divide result would've had remainder");
563
+ }
564
+ return a / b;
565
+ }
566
+
567
+ static void ggml_vk_add(
568
+ kp::Sequence& seq,
569
+ const std::shared_ptr<kp::Tensor>& inA,
570
+ const std::shared_ptr<kp::Tensor>& inB,
571
+ const std::shared_ptr<kp::Tensor>& out,
572
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
573
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
574
+ int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
575
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
576
+ int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
577
+ int32_t ne0,
578
+ int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
579
+ ) {
580
+ const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
581
+ kp::shader_data::op_add_comp_spv_len);
582
+
583
+ struct PushConstants {
584
+ uint32_t inAOff, inBOff, outOff;
585
+ int32_t ne00;
586
+ int32_t nb00, nb01, nb02, nb03;
587
+ int32_t ne10, ne11, ne12, ne13;
588
+ int32_t nb10, nb11, nb12, nb13;
589
+ int32_t ne0;
590
+ int32_t nb0, nb1, nb2, nb3;
591
+ } const pushConsts {
592
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
593
+ ne00,
594
+ nb00, nb01, nb02, nb03,
595
+ ne10, ne11, ne12, ne13,
596
+ nb10, nb11, nb12, nb13,
597
+ ne0,
598
+ nb0, nb1, nb2, nb3
599
+ };
600
+
601
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
602
+ if (!komputeManager()->hasAlgorithm(__func__)) {
603
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
604
+ } else {
605
+ s_algo = komputeManager()->getAlgorithm(__func__);
606
+ s_algo->setTensors({inA, inB, out});
607
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
608
+ s_algo->setPushConstants<PushConstants>({pushConsts});
609
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
610
+ }
611
+ seq.record<kp::OpAlgoDispatch>(s_algo);
612
+ }
613
+
614
+ static void ggml_vk_addrow(kp::Sequence& seq,
615
+ const std::shared_ptr<kp::Tensor>& inA,
616
+ const std::shared_ptr<kp::Tensor>& inB,
617
+ const std::shared_ptr<kp::Tensor>& out,
618
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
619
+ uint32_t size, uint32_t row = 0) {
620
+
621
+ const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
622
+ kp::shader_data::op_addrow_comp_spv_len);
623
+
624
+ struct PushConstants {
625
+ uint32_t inAOff, inBOff, outOff;
626
+ uint32_t row;
627
+ } const pushConsts {
628
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
629
+ row
630
+ };
631
+
632
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
633
+ if (!komputeManager()->hasAlgorithm(__func__))
634
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
635
+ else {
636
+ s_algo = komputeManager()->getAlgorithm(__func__);
637
+ s_algo->setTensors({inA, inB, out});
638
+ s_algo->setWorkgroup({size});
639
+ s_algo->setPushConstants<PushConstants>({pushConsts});
640
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
641
+ }
642
+ seq.record<kp::OpAlgoDispatch>(s_algo);
643
+ }
644
+
645
+ static void ggml_vk_mul(
646
+ kp::Sequence& seq,
647
+ const std::shared_ptr<kp::Tensor>& inA,
648
+ const std::shared_ptr<kp::Tensor>& inB,
649
+ const std::shared_ptr<kp::Tensor>& out,
650
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
651
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
652
+ int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
653
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
654
+ int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
655
+ int32_t ne0,
656
+ int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
657
+ ) {
658
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
659
+ kp::shader_data::op_mul_comp_spv_len);
660
+
661
+ struct PushConstants {
662
+ uint32_t inAOff, inBOff, outOff;
663
+ int32_t ne00;
664
+ int32_t nb00, nb01, nb02, nb03;
665
+ int32_t ne10, ne11, ne12, ne13;
666
+ int32_t nb10, nb11, nb12, nb13;
667
+ int32_t ne0;
668
+ int32_t nb0, nb1, nb2, nb3;
669
+ } const pushConsts {
670
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
671
+ ne00,
672
+ nb00, nb01, nb02, nb03,
673
+ ne10, ne11, ne12, ne13,
674
+ nb10, nb11, nb12, nb13,
675
+ ne0,
676
+ nb0, nb1, nb2, nb3
677
+ };
678
+
679
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
680
+ if (!komputeManager()->hasAlgorithm(__func__)) {
681
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
682
+ } else {
683
+ s_algo = komputeManager()->getAlgorithm(__func__);
684
+ s_algo->setTensors({inA, inB, out});
685
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
686
+ s_algo->setPushConstants<PushConstants>({pushConsts});
687
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
688
+ }
689
+ seq.record<kp::OpAlgoDispatch>(s_algo);
690
+ }
691
+
692
+ static void ggml_vk_scale(kp::Sequence& seq,
693
+ const std::shared_ptr<kp::Tensor>& in,
694
+ const std::shared_ptr<kp::Tensor>& out,
695
+ uint32_t inOff, uint32_t outOff,
696
+ uint32_t size, float scale) {
697
+ const static auto spirv_1 = getSpirvShader(
698
+ kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
699
+ );
700
+ const static auto spirv_8 = getSpirvShader(
701
+ kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
702
+ );
703
+
704
+ struct PushConstants {
705
+ uint32_t inOff, outOff;
706
+ float scale;
707
+ } const pushConsts {
708
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
709
+ scale
710
+ };
711
+
712
+ const auto * spirv = &spirv_1;
713
+ std::string name(__func__);
714
+ if (size % 8 == 0) {
715
+ size /= 8;
716
+ name += "_8";
717
+ spirv = &spirv_8;
718
+ }
719
+
720
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
721
+ if (!komputeManager()->hasAlgorithm(name)) {
722
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
723
+ } else {
724
+ s_algo = komputeManager()->getAlgorithm(name);
725
+ s_algo->setTensors({in, out});
726
+ s_algo->setWorkgroup({size});
727
+ s_algo->setPushConstants<PushConstants>({pushConsts});
728
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
729
+ }
730
+ seq.record<kp::OpAlgoDispatch>(s_algo);
731
+ }
732
+
733
+ static void ggml_vk_xxlu(
734
+ const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
735
+ const std::shared_ptr<kp::Tensor>& in,
736
+ const std::shared_ptr<kp::Tensor>& out,
737
+ uint32_t inOff, uint32_t outOff,
738
+ uint32_t size
739
+ ) {
740
+ struct PushConstants {
741
+ uint32_t inOff, outOff;
742
+ } const pushConsts {
743
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
744
+ };
745
+
746
+ auto name = std::string(__func__) + "_" + suffix;
747
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
748
+ if (!komputeManager()->hasAlgorithm(name)) {
749
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
750
+ } else {
751
+ s_algo = komputeManager()->getAlgorithm(name);
752
+ s_algo->setTensors({in, out});
753
+ s_algo->setWorkgroup({size});
754
+ s_algo->setPushConstants<PushConstants>({pushConsts});
755
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
756
+ }
757
+ seq.record<kp::OpAlgoDispatch>(s_algo);
758
+ }
759
+
760
+ template <typename... Args>
761
+ static void ggml_vk_silu(Args&&... args) {
762
+ const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
763
+ kp::shader_data::op_silu_comp_spv_len);
764
+
765
+ ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
766
+ }
767
+
768
+ template <typename... Args>
769
+ static void ggml_vk_relu(Args&&... args) {
770
+ const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
771
+ kp::shader_data::op_relu_comp_spv_len);
772
+
773
+ ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
774
+ }
775
+
776
+ template <typename... Args>
777
+ static void ggml_vk_gelu(Args&&... args) {
778
+ const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
779
+ kp::shader_data::op_gelu_comp_spv_len);
780
+
781
+ ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
782
+ }
783
+
784
+ static void ggml_vk_soft_max(
785
+ kp::Sequence& seq,
786
+ const std::shared_ptr<kp::Tensor>& inA,
787
+ const std::shared_ptr<kp::Tensor>& inB,
788
+ const std::shared_ptr<kp::Tensor>& out,
789
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
790
+ int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
791
+ float scale
792
+ ) {
793
+ const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
794
+ kp::shader_data::op_softmax_comp_spv_len);
795
+
796
+ struct PushConstants {
797
+ uint32_t inAOff, inBOff, outOff;
798
+ int32_t ne00, ne01, ne02;
799
+ float scale;
800
+ int32_t mask;
801
+ } pushConsts {
802
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
803
+ ne00, ne01, ne02,
804
+ scale,
805
+ bool(inB)
806
+ };
807
+
808
+ auto & inB_ = inB ? inB : inA;
809
+
810
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
811
+ if (!komputeManager()->hasAlgorithm(__func__)) {
812
+ // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
813
+ const uint32_t local_x = 32;
814
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
815
+ } else {
816
+ s_algo = komputeManager()->getAlgorithm(__func__);
817
+ s_algo->setTensors({inA, inB_, out});
818
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
819
+ s_algo->setPushConstants<PushConstants>({pushConsts});
820
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
821
+ }
822
+ seq.record<kp::OpAlgoDispatch>(s_algo);
823
+ }
824
+
825
+ static void ggml_vk_norm_(
826
+ const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
827
+ const std::shared_ptr<kp::Tensor>& in,
828
+ const std::shared_ptr<kp::Tensor>& out,
829
+ uint32_t inOff, uint32_t outOff,
830
+ int32_t ne00, int32_t nb01,
831
+ int32_t nrows, float epsilon
832
+ ) {
833
+ GGML_ASSERT(nb01%sizeof(float) == 0);
834
+ GGML_ASSERT(ne00%sizeof(float) == 0);
835
+
836
+ struct PushConstants {
837
+ uint32_t inOff, outOff;
838
+ uint32_t ne00, nb01;
839
+ float eps;
840
+ } pushConsts {
841
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
842
+ (uint32_t)ne00, (uint32_t)nb01, epsilon
843
+ };
844
+
845
+ auto name = std::string(__func__) + "_" + suffix;
846
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
847
+ if (!komputeManager()->hasAlgorithm(name)) {
848
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
849
+ } else {
850
+ s_algo = komputeManager()->getAlgorithm(name);
851
+ s_algo->setTensors({in, out});
852
+ s_algo->setWorkgroup({(uint32_t)nrows});
853
+ s_algo->setPushConstants<PushConstants>({pushConsts});
854
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
855
+ }
856
+ seq.record<kp::OpAlgoDispatch>(s_algo);
857
+ }
858
+
859
+ template <typename... Args>
860
+ static void ggml_vk_norm(Args&&... args) {
861
+ const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
862
+ kp::shader_data::op_norm_comp_spv_len);
863
+
864
+ ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
865
+ }
866
+
867
+ template <typename... Args>
868
+ static void ggml_vk_rms_norm(Args&&... args) {
869
+ const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
870
+ kp::shader_data::op_rmsnorm_comp_spv_len);
871
+
872
+ ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
873
+ }
874
+
875
+ static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
876
+ const std::shared_ptr<kp::Tensor>& in,
877
+ const std::shared_ptr<kp::Tensor>& out,
878
+ uint32_t inOff, uint32_t outOff,
879
+ uint32_t n_past,
880
+ int32_t ne00, int32_t ne01, int32_t ne02) {
881
+ const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
882
+ kp::shader_data::op_diagmask_comp_spv_len);
883
+
884
+ struct PushConstants {
885
+ uint32_t inOff, outOff;
886
+ uint32_t n_past;
887
+ int32_t ne00, ne01;
888
+ } pushConsts {
889
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
890
+ n_past,
891
+ ne00, ne01
892
+ };
893
+
894
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
895
+ if (!komputeManager()->hasAlgorithm(__func__))
896
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
897
+ else {
898
+ s_algo = komputeManager()->getAlgorithm(__func__);
899
+ s_algo->setTensors({in, out});
900
+ s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
901
+ s_algo->setPushConstants<PushConstants>({pushConsts});
902
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
903
+ }
904
+ seq.record<kp::OpAlgoDispatch>(s_algo);
905
+ }
906
+
907
+ static void ggml_vk_mul_mat_f16(
908
+ kp::Sequence& seq,
909
+ const std::shared_ptr<kp::Tensor>& inA,
910
+ const std::shared_ptr<kp::Tensor>& inB,
911
+ const std::shared_ptr<kp::Tensor>& out,
912
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
913
+ int32_t ne00, int32_t ne01, int32_t ne02,
914
+ uint32_t nb00, uint32_t nb01, uint32_t nb02,
915
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
916
+ uint32_t nb10, uint32_t nb11, uint32_t nb12,
917
+ int32_t ne0, int32_t ne1,
918
+ uint32_t r2, uint32_t r3
919
+ ) {
920
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
921
+ kp::shader_data::op_mul_mat_f16_comp_spv_len);
922
+
923
+ struct PushConstants {
924
+ uint32_t inAOff, inBOff, outOff;
925
+ int32_t ne00, ne01, ne02;
926
+ uint32_t nb00, nb01, nb02;
927
+ int32_t ne10, ne11, ne12;
928
+ uint32_t nb10, nb11, nb12;
929
+ int32_t ne0, ne1;
930
+ uint32_t r2, r3;
931
+ } pushConsts {
932
+ safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
933
+ ne00, ne01, ne02,
934
+ nb00, nb01, nb02,
935
+ ne10, ne11, ne12,
936
+ nb10, nb11, nb12,
937
+ ne0, ne1,
938
+ r2, r3
939
+ };
940
+
941
+ const unsigned ny = unsigned((ne11 + 4 - 1)/4);
942
+
943
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
944
+ if (!komputeManager()->hasAlgorithm(__func__)) {
945
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
946
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
947
+ } else {
948
+ s_algo = komputeManager()->getAlgorithm(__func__);
949
+ s_algo->setTensors({inA, inB, out});
950
+ s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
951
+ s_algo->setPushConstants<PushConstants>({pushConsts});
952
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
953
+ }
954
+ seq.record<kp::OpAlgoDispatch>(s_algo);
955
+ }
956
+
957
+ static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
958
+ const std::shared_ptr<kp::Tensor>& inA,
959
+ const std::shared_ptr<kp::Tensor>& inB,
960
+ const std::shared_ptr<kp::Tensor>& out,
961
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
962
+ int32_t ne00, int32_t ne01, int32_t ne02,
963
+ uint32_t nb01, uint32_t nb02,
964
+ int32_t ne11, int32_t ne12,
965
+ uint32_t nb11, uint32_t nb12,
966
+ uint32_t nb1, uint32_t nb2) {
967
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
968
+ kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
969
+
970
+ struct PushConstants {
971
+ uint32_t inAOff, inBOff, outOff;
972
+ int32_t ne00, ne01, ne02, ne11, ne12;
973
+ uint32_t nb01, nb02;
974
+ uint32_t nb11, nb12;
975
+ uint32_t nb1, nb2;
976
+ } pushConsts {
977
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
978
+ ne00, ne01, ne02, ne11, ne12,
979
+ nb01, nb02, nb11, nb12,
980
+ nb1, nb2
981
+ };
982
+
983
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize;
984
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
985
+ if (!komputeManager()->hasAlgorithm(__func__)) {
986
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
987
+ {inA, inB, out}, spirv,
988
+ {unsigned(ne01),
989
+ unsigned(ne11),
990
+ unsigned(std::max(ne12, ne02))
991
+ },
992
+ {local_x},
993
+ {pushConsts});
994
+ } else {
995
+ s_algo = komputeManager()->getAlgorithm(__func__);
996
+ s_algo->setTensors({inA, inB, out});
997
+ s_algo->setWorkgroup({unsigned(ne01),
998
+ unsigned(ne11),
999
+ unsigned(std::max(ne12, ne02)),
1000
+ });
1001
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1002
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1003
+ }
1004
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1005
+ }
1006
+
1007
+ static void ggml_vk_mul_mat_impl(
1008
+ const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
1009
+ const std::shared_ptr<kp::Tensor>& inA,
1010
+ const std::shared_ptr<kp::Tensor>& inB,
1011
+ const std::shared_ptr<kp::Tensor>& out,
1012
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1013
+ int32_t ne00, int32_t ne01, int32_t ne02,
1014
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
1015
+ int32_t ne0, int32_t ne1,
1016
+ uint32_t r2, uint32_t r3
1017
+ ) {
1018
+ struct PushConstants {
1019
+ uint32_t inAOff, inBOff, outOff;
1020
+ int32_t ne00, ne01, ne02;
1021
+ int32_t ne10, ne12;
1022
+ int32_t ne0, ne1;
1023
+ uint32_t r2, r3;
1024
+ } pushConsts {
1025
+ safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
1026
+ ne00, ne01, ne02,
1027
+ ne10, ne12,
1028
+ ne0, ne1,
1029
+ r2, r3
1030
+ };
1031
+
1032
+ auto name = std::string(__func__) + "_" + suffix;
1033
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1034
+ if (!komputeManager()->hasAlgorithm(name)) {
1035
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1036
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
1037
+ } else {
1038
+ s_algo = komputeManager()->getAlgorithm(name);
1039
+ s_algo->setTensors({inA, inB, out});
1040
+ s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
1041
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1042
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1043
+ }
1044
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1045
+ }
1046
+
1047
+ template <typename... Args>
1048
+ static void ggml_vk_mul_mat_q4_0(Args&&... args) {
1049
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
1050
+ kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
1051
+
1052
+ ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1053
+ }
1054
+
1055
+ template <typename... Args>
1056
+ static void ggml_vk_mul_mat_q4_1(Args&&... args) {
1057
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
1058
+ kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
1059
+
1060
+ ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1061
+ }
1062
+
1063
+ template <typename... Args>
1064
+ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
1065
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
1066
+ kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
1067
+
1068
+ ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1069
+ }
1070
+
1071
+ static void ggml_vk_mul_mat_q4_k(
1072
+ kp::Sequence& seq,
1073
+ const std::shared_ptr<kp::Tensor>& inA,
1074
+ const std::shared_ptr<kp::Tensor>& inB,
1075
+ const std::shared_ptr<kp::Tensor>& out,
1076
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1077
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
1078
+ int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
1079
+ int32_t ne1, int32_t r2, int32_t r3
1080
+ ) {
1081
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
1082
+ kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
1083
+
1084
+ struct PushConstants {
1085
+ uint32_t inAOff, inBOff, outOff;
1086
+ int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
1087
+ } pushConsts {
1088
+ 0, 0, 0,
1089
+ ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
1090
+ };
1091
+
1092
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1093
+ if (!komputeManager()->hasAlgorithm(__func__)) {
1094
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
1095
+ } else {
1096
+ s_algo = komputeManager()->getAlgorithm(__func__);
1097
+ s_algo->setTensors({inA, inB, out});
1098
+ s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
1099
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1100
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1101
+ }
1102
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1103
+ }
1104
+
1105
+ static void ggml_vk_mul_mat_q6_k(
1106
+ kp::Sequence& seq,
1107
+ const std::shared_ptr<kp::Tensor>& inA,
1108
+ const std::shared_ptr<kp::Tensor>& inB,
1109
+ const std::shared_ptr<kp::Tensor>& out,
1110
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1111
+ int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
1112
+ int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
1113
+ ) {
1114
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
1115
+ kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
1116
+
1117
+ struct PushConstants {
1118
+ uint32_t inAOff, inBOff, outOff;
1119
+ int32_t ne00, ne10, ne0, ne1, ne01, gqa;
1120
+ } pushConsts {
1121
+ inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
1122
+ ne00, ne10, ne0, ne1, ne01, ne12/ne02
1123
+ };
1124
+
1125
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1126
+ if (!komputeManager()->hasAlgorithm(__func__)) {
1127
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1128
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
1129
+ } else {
1130
+ s_algo = komputeManager()->getAlgorithm(__func__);
1131
+ s_algo->setTensors({inA, inB, out});
1132
+ s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
1133
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1134
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1135
+ }
1136
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1137
+ }
1138
+
1139
+ static void ggml_vk_get_rows(
1140
+ const std::vector<uint32_t>& spirv,
1141
+ const char * suffix,
1142
+ unsigned element_size, unsigned qk,
1143
+ kp::Sequence& seq,
1144
+ const std::shared_ptr<kp::Tensor>& inA,
1145
+ const std::shared_ptr<kp::Tensor>& inB,
1146
+ const std::shared_ptr<kp::Tensor>& out,
1147
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1148
+ int32_t ne00, int32_t nb01, int32_t nb1,
1149
+ uint32_t size
1150
+ ) {
1151
+ GGML_ASSERT(nb01%element_size == 0);
1152
+ GGML_ASSERT(nb1%sizeof(float) == 0);
1153
+ if (qk) GGML_ASSERT(ne00%qk == 0);
1154
+
1155
+ struct PushConstants {
1156
+ uint32_t inAOff, inBOff, outOff;
1157
+ int32_t ne00, nb01, nb1;
1158
+ } pushConsts {
1159
+ safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
1160
+ ne00, nb01, nb1
1161
+ };
1162
+
1163
+ auto name = std::string(__func__) + "_" + suffix;
1164
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1165
+ if (!komputeManager()->hasAlgorithm(name)) {
1166
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
1167
+ } else {
1168
+ s_algo = komputeManager()->getAlgorithm(name);
1169
+ s_algo->setTensors({inA, inB, out});
1170
+ s_algo->setWorkgroup({size});
1171
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1172
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1173
+ }
1174
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1175
+ }
1176
+
1177
+ template <typename... Args>
1178
+ static void ggml_vk_get_rows_f32(Args&&... args) {
1179
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
1180
+ kp::shader_data::op_getrows_f32_comp_spv_len);
1181
+
1182
+ ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
1183
+ }
1184
+
1185
+ template <typename... Args>
1186
+ static void ggml_vk_get_rows_f16(Args&&... args) {
1187
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
1188
+ kp::shader_data::op_getrows_f16_comp_spv_len);
1189
+
1190
+ ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
1191
+ }
1192
+
1193
+ template <typename... Args>
1194
+ static void ggml_vk_get_rows_q4_0(Args&&... args) {
1195
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
1196
+ kp::shader_data::op_getrows_q4_0_comp_spv_len);
1197
+
1198
+ ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
1199
+ }
1200
+
1201
+ template <typename... Args>
1202
+ static void ggml_vk_get_rows_q4_1(Args&&... args) {
1203
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
1204
+ kp::shader_data::op_getrows_q4_1_comp_spv_len);
1205
+
1206
+ ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
1207
+ }
1208
+
1209
+ template <typename... Args>
1210
+ static void ggml_vk_get_rows_q6_k(Args&&... args) {
1211
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
1212
+ kp::shader_data::op_getrows_q6_k_comp_spv_len);
1213
+ ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
1214
+ }
1215
+
1216
+ static void ggml_vk_rope(
1217
+ kp::Sequence& seq,
1218
+ const std::shared_ptr<kp::Tensor>& inA,
1219
+ const std::shared_ptr<kp::Tensor>& inB,
1220
+ const std::shared_ptr<kp::Tensor>& out,
1221
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1222
+ ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
1223
+ float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
1224
+ int32_t ne01, int32_t ne02, int32_t ne03,
1225
+ uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1226
+ int32_t ne0,
1227
+ uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1228
+ ) {
1229
+ GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
1230
+
1231
+ static const auto spirv_f16 = getSpirvShader(
1232
+ kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
1233
+ );
1234
+ static const auto spirv_f32 = getSpirvShader(
1235
+ kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
1236
+ );
1237
+
1238
+ int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
1239
+
1240
+ GGML_ASSERT(nb03 % type_size == 0);
1241
+ GGML_ASSERT(nb02 % type_size == 0);
1242
+ GGML_ASSERT(nb01 % type_size == 0);
1243
+ GGML_ASSERT(nb00 % type_size == 0);
1244
+ GGML_ASSERT(nb3 % type_size == 0);
1245
+ GGML_ASSERT(nb2 % type_size == 0);
1246
+ GGML_ASSERT(nb1 % type_size == 0);
1247
+ GGML_ASSERT(nb0 % type_size == 0);
1248
+
1249
+ struct PushConstants {
1250
+ uint32_t inAOff, inBOff, outOff;
1251
+ int32_t n_dims, mode, n_ctx_orig;
1252
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1253
+ uint32_t nb00, nb01, nb02, nb03;
1254
+ int32_t ne0;
1255
+ uint32_t nb0, nb1, nb2, nb3;
1256
+ } pushConsts {
1257
+ safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
1258
+ n_dims, mode, n_ctx_orig,
1259
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1260
+ nb00, nb01, nb02, nb03,
1261
+ ne0,
1262
+ nb0, nb1, nb2, nb3
1263
+ };
1264
+
1265
+ auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
1266
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1267
+ if (!komputeManager()->hasAlgorithm(name)) {
1268
+ s_algo = komputeManager()->algorithm<float, PushConstants>(
1269
+ name, s_kompute_context->pool.get(), {inA, inB, out},
1270
+ src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
1271
+ {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
1272
+ );
1273
+ } else {
1274
+ s_algo = komputeManager()->getAlgorithm(name);
1275
+ s_algo->setTensors({inA, inB, out});
1276
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1277
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1278
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1279
+ }
1280
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1281
+ }
1282
+
1283
+ static void ggml_vk_cpy(
1284
+ const std::vector<uint32_t>& spirv,
1285
+ uint32_t in_element_size, uint32_t out_element_size,
1286
+ kp::Sequence& seq,
1287
+ const std::shared_ptr<kp::Tensor>& in,
1288
+ const std::shared_ptr<kp::Tensor>& out,
1289
+ uint32_t inOff, uint32_t outOff,
1290
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
1291
+ uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1292
+ int32_t ne0, int32_t ne1, int32_t ne2,
1293
+ uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1294
+ ) {
1295
+ struct PushConstants {
1296
+ uint32_t inOff, outOff;
1297
+ int32_t ne00, ne01, ne02;
1298
+ uint32_t nb00, nb01, nb02, nb03;
1299
+ int32_t ne0, ne1, ne2;
1300
+ uint32_t nb0, nb1, nb2, nb3;
1301
+ } pushConsts {
1302
+ safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
1303
+ ne00, ne01, ne02,
1304
+ nb00, nb01, nb02, nb03,
1305
+ ne0, ne1, ne2,
1306
+ nb0, nb1, nb2, nb3
1307
+ };
1308
+
1309
+ std::string name = std::string(__func__)
1310
+ + "_i_" + std::to_string(in_element_size)
1311
+ + "_o_" + std::to_string(out_element_size);
1312
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1313
+ if (!komputeManager()->hasAlgorithm(name))
1314
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
1315
+ else {
1316
+ s_algo = komputeManager()->getAlgorithm(name);
1317
+ s_algo->setTensors({in, out});
1318
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1319
+ s_algo->setPushConstants<PushConstants>({pushConsts});
1320
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
1321
+ }
1322
+ seq.record<kp::OpAlgoDispatch>(s_algo);
1323
+ }
+
+ template <typename... Args>
+ static void ggml_vk_cpy_f32_f16(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
+ kp::shader_data::op_cpy_f32_f16_comp_spv_len);
+ ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static void ggml_vk_cpy_f32_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
+ kp::shader_data::op_cpy_f32_f32_comp_spv_len);
+ ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static void ggml_vk_cpy_f16_f16(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
+ kp::shader_data::op_cpy_f16_f16_comp_spv_len);
+ ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ static void ggml_vk_cpy_f16_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
+ kp::shader_data::op_cpy_f16_f32_comp_spv_len);
+ ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
+ }
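+
+ // Usage sketch (hypothetical, mirroring the GGML_OP_CPY dispatch further
+ // down): the element sizes baked into each wrapper (4 bytes for F32, 2 for
+ // F16) are what ggml_vk_cpy uses to turn byte offsets into element offsets.
+ //
+ //     if (src0t == GGML_TYPE_F32 && dstt == GGML_TYPE_F16) {
+ //         ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst,
+ //                             ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+ //                             ne0, ne1, ne2, nb0, nb1, nb2, nb3);
+ //     }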
+
+ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ ;
+ }
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_SCALE:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ROPE:
+ return true;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ default:
+ return false;
+ }
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ default:
+ return false;
+ }
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ return op->ne[3] == 1;
+ case GGML_OP_GET_ROWS:
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q6_K:
+ return op->ne[2] == 1 && op->ne[3] == 1;
+ default:
+ ;
+ }
+ return false;
+ case GGML_OP_MUL_MAT:
+ if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
+ return false;
+
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q6_K:
+ return op->ne[3] == 1;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q4_K:
+ return true;
+ default:
+ ;
+ }
+ default:
+ ;
+ }
+ return false;
+
+ GGML_UNUSED(dev);
+ }
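+
+ // ggml's backend scheduler consults supports_op() per graph node; a node
+ // rejected here is simply left for another backend (typically the CPU) to run.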
+
+ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
+ const int n_seq = 8;
+
+ // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
+ // it to the size of the graph, but I think it can be made smaller?
+ ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
+
+ std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
+
+ for (auto& sequence : sequences) {
+ sequence = komputeManager()->sequence();
+ }
+ for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
+ const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
+
+ auto& seq = *sequences[seq_idx];
+
+ const int node_start = (seq_idx + 0) * n_nodes_per_seq;
+ const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
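+ // e.g. gf->n_nodes == 100 with n_seq == 8 gives n_nodes_per_seq == 13:
+ // sequences 0..6 record nodes [0,13), [13,26), ..., [78,91) and the last
+ // sequence records the remainder [91,100).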
+
+ bool any_commands_recorded = false;
+
+ for (int i = node_start; i < node_end; ++i) {
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
+ struct ggml_tensor * dst = gf->nodes[i];
+ GGML_ASSERT(dst->data != nullptr);
+
+ if (ggml_is_empty(dst)) {
+ continue;
+ }
+
+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ continue; // noop -> next node
+ default:
+ break;
+ }
+
+ any_commands_recorded = true;
+
+ const int32_t ne00 = src0 ? src0->ne[0] : 0;
+ const int32_t ne01 = src0 ? src0->ne[1] : 0;
+ const int32_t ne02 = src0 ? src0->ne[2] : 0;
+ const int32_t ne03 = src0 ? src0->ne[3] : 0;
+
+ const uint32_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint32_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint32_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint32_t nb03 = src0 ? src0->nb[3] : 0;
+
+ const int32_t ne10 = src1 ? src1->ne[0] : 0;
+ const int32_t ne11 = src1 ? src1->ne[1] : 0;
+ const int32_t ne12 = src1 ? src1->ne[2] : 0;
+ const int32_t ne13 = src1 ? src1->ne[3] : 0;
+
+ const uint32_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint32_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint32_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint32_t nb13 = src1 ? src1->nb[3] : 0;
+
+ const int32_t ne0 = dst ? dst->ne[0] : 0;
+ const int32_t ne1 = dst ? dst->ne[1] : 0;
+ const int32_t ne2 = dst ? dst->ne[2] : 0;
+ // const int32_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint32_t nb0 = dst ? dst->nb[0] : 0;
+ const uint32_t nb1 = dst ? dst->nb[1] : 0;
+ const uint32_t nb2 = dst ? dst->nb[2] : 0;
+ const uint32_t nb3 = dst ? dst->nb[3] : 0;
+
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+ const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
+ uint32_t off_src0 = 0;
+ uint32_t off_src1 = 0;
+ uint32_t off_dst = 0;
+ const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
+ const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+ const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
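+
+ // ggml_vk_get_tensor maps each ggml tensor to the kp::Tensor backing its
+ // buffer plus a byte offset into it; the op helpers above scale these
+ // offsets to element units (safe_divide) before handing them to the shaders.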
+
+ switch (dst->op) {
+ case GGML_OP_ADD:
+ {
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ // src1 is a row
+ ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
+ } else {
+ ggml_vk_add(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne03,
+ nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb10, nb11, nb12, nb13,
+ ne0,
+ nb0, nb1, nb2, nb3
+ );
+ }
+ } break;
+ case GGML_OP_MUL:
+ {
+ ggml_vk_mul(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne03,
+ nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb10, nb11, nb12, nb13,
+ ne0,
+ nb0, nb1, nb2, nb3
+ );
+ } break;
+ case GGML_OP_SCALE:
+ {
+ float scale; memcpy(&scale, dst->op_params, sizeof(float));
+
+ ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
+ } break;
+ case GGML_OP_UNARY:
+ {
+ int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ GGML_ASSERT(n % 8 == 0);
+ ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ABORT("fatal error");
+ }
+ }
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+ #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+ GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+
+ #pragma message("TODO: add ALiBi support")
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+ GGML_ASSERT(max_bias == 0.0f);
+
+ ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ const int n_past = ((int32_t *)(dst->op_params))[0];
+ ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
+ } break;
+ case GGML_OP_NORM:
+ {
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ GGML_ASSERT(ne00 == ne10);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint32_t r2 = ne12/ne02;
+ const uint32_t r3 = ne13/ne03;
+
+ if (src1t != GGML_TYPE_F32) {
+ fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+ goto not_implemented;
+ }
+
+ if (ggml_is_transposed(src0) ||
+ ggml_is_transposed(src1)) {
1638
+ fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1639
+ goto not_implemented;
+ }
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ ggml_vk_mul_mat_mat_f32(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
+ );
+ break;
+ case GGML_TYPE_F16:
+ ggml_vk_mul_mat_f16(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
+ ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q8_0:
+ ggml_vk_mul_mat_q8_0(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q4_0:
+ ggml_vk_mul_mat_q4_0(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q4_1:
+ ggml_vk_mul_mat_q4_1(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q4_K:
+ ggml_vk_mul_mat_q4_k(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
+ );
+ break;
+ case GGML_TYPE_Q6_K:
+ ggml_vk_mul_mat_q6_k(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
+ );
+ break;
+ default: {
+ fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+ goto not_implemented;
+ }
+ }
+
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ if (src0t == GGML_TYPE_F32) {
+ ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_F16) {
+ ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_Q4_0) {
+ ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_Q4_1) {
+ ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_Q6_K) {
+ ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else {
+ fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
+ goto not_implemented;
+ }
+ } break;
+ case GGML_OP_ROPE:
+ {
+ #pragma message("TODO: implement phi3 frequency factors support")
+ #pragma message("  https://github.com/ggerganov/llama.cpp/pull/7225")
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+
+ #pragma message("TODO: update rope NORM mode to match NEOX mode")
+ #pragma message("  https://github.com/ggerganov/llama.cpp/pull/7634")
+
+ GGML_ASSERT(ne10 == ne02);
+ GGML_ASSERT(src0t == dstt);
+ // const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ ggml_vk_rope(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+ ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
+ );
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+ case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+ default: goto not_implemented;
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+ case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+ default: goto not_implemented;
+ }
+ } break;
+ default: goto not_implemented;
+ }
+ } break;
+ default: goto not_implemented;
+ }
+ continue;
+ not_implemented: {}
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ //GGML_ABORT("fatal error");
+ }
+
+ // Evaluate sequence
+ if (any_commands_recorded) {
+ seq.evalAsync();
+ }
+ }
+
+ // Wait for all sequences to finish
+ for (auto& sequence : sequences) {
+ if (sequence->isRunning())
+ sequence->evalAwait();
+ }
+
+ ggml_vk_free_descriptor_pool(ctx);
+ }
+
+ template<>
+ kp::Tensor::TensorDataTypes
+ kp::TensorT<half>::dataType()
+ {
+ return TensorDataTypes::eFloat;
+ }
+
+ template<>
+ kp::Tensor::TensorDataTypes
+ kp::TensorT<uint8_t>::dataType()
+ {
+ return TensorDataTypes::eUnsignedInt;
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ // backend interface
+
+ struct ggml_backend_kompute_buffer_type_context {
+ int device;
+ int device_ref = 0;
+ uint64_t buffer_alignment;
+ uint64_t max_alloc;
+ std::string name;
+
+ ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
+ : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
+ };
+
+ static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+
+ if (!ctx->device_ref) {
+ komputeManager()->initializeDevice(
+ ctx->device, {}, {
+ "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
+ "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
+ }
+ );
+ }
+
+ assert(ggml_vk_has_device());
+ ctx->device_ref++;
+ }
+
+ static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+
+ assert(ctx->device_ref > 0);
+
+ ctx->device_ref--;
+
+ if (!ctx->device_ref) {
+ komputeManager.destroy();
+ }
+ }
+
+ static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ auto * memory = (ggml_vk_memory *)buffer->context;
+ if (ggml_vk_has_device()) {
+ ggml_vk_free_memory(*memory);
+ }
+ delete memory;
+ }
+
+ static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
+ return ((ggml_vk_memory *)buffer->context)->data;
+ }
+
+ static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_UNUSED(buffer);
+
+ const auto res = ggml_vk_get_tensor(tensor);
+ GGML_ASSERT(res);
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
+ }
+
+ static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_UNUSED(buffer);
+
+ const auto res = ggml_vk_get_tensor(tensor);
+ GGML_ASSERT(res);
+
+ komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+ }
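+
+ // The set/get pair above works because Kompute buffers are host-mapped:
+ // set_tensor writes through tensor->data and then uploads with
+ // OpTensorSyncDevice, while get_tensor downloads with OpTensorSyncLocal
+ // before reading the mapped memory back out.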
+
+ static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ auto * memory = (ggml_vk_memory *)buffer->context;
+ memset(memory->data, value, buffer->size);
+
+ if (memory->stagingBuffer)
+ komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
+ }
+
+ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
+ /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_kompute_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .memset_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_kompute_buffer_clear,
+ /* .reset = */ NULL,
+ };
+
+ // default buffer type
+
+ static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+ return ctx->name.c_str();
+ }
+
+ static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ ggml_backend_kompute_device_ref(buft);
+ auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
+ return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
+ }
+
+ static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+ return ctx->buffer_alignment;
+ }
+
+ static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+ return ctx->max_alloc;
+ }
+
+ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .is_host = */ NULL,
+ };
+
+ ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ auto devices = ggml_vk_available_devices();
+ int32_t device_count = (int32_t) devices.size();
+ GGML_ASSERT(device < device_count);
+ GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
+
+ static ggml_backend_buffer_type
+ ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
+
+ static bool ggml_backend_kompute_buffer_type_initialized = false;
+
+ if (!ggml_backend_kompute_buffer_type_initialized) {
+ for (int32_t i = 0; i < device_count; i++) {
+ ggml_backend_kompute_buffer_types[i] = {
+ /* .iface = */ ggml_backend_kompute_buffer_type_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
+ /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
+ };
+ }
+ ggml_backend_kompute_buffer_type_initialized = true;
+ }
+
+ return &ggml_backend_kompute_buffer_types[device];
+ }
+
+ // backend
+
+ static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+ return ctx->name.c_str();
+ }
+
+ static void ggml_backend_kompute_free(ggml_backend_t backend) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+
+ assert(ctx == s_kompute_context);
+ s_kompute_context = nullptr;
+ if (ctx != nullptr) {
+ delete ctx;
+ }
+
+ delete backend;
+ }
+
+ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+ ggml_vk_graph_compute(ctx, cgraph);
+ return GGML_STATUS_SUCCESS;
+ }
+
+ static struct ggml_backend_i kompute_backend_i = {
+ /* .get_name = */ ggml_backend_kompute_name,
+ /* .free = */ ggml_backend_kompute_free,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_kompute_graph_compute,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ };
+
+ static ggml_guid_t ggml_backend_kompute_guid() {
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
+ return &guid;
+ }
+
+ ggml_backend_t ggml_backend_kompute_init(int device) {
+ GGML_ASSERT(s_kompute_context == nullptr);
+ s_kompute_context = new ggml_kompute_context(device);
+
+ ggml_backend_t kompute_backend = new ggml_backend {
+ /* .guid = */ ggml_backend_kompute_guid(),
+ /* .interface = */ kompute_backend_i,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
+ /* .context = */ s_kompute_context,
+ };
+
+ return kompute_backend;
+ }
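+
+ // Usage sketch (hypothetical caller, assuming device 0 exists; error handling
+ // omitted). ggml_backend_graph_compute() routes to
+ // ggml_backend_kompute_graph_compute() via the interface table above.
+ //
+ //     ggml_backend_t backend = ggml_backend_kompute_init(0);
+ //     GGML_ASSERT(ggml_backend_is_kompute(backend));
+ //     // ... allocate tensors in a buffer from
+ //     // ggml_backend_kompute_buffer_type(0), build a ggml_cgraph * graph ...
+ //     ggml_backend_graph_compute(backend, graph);
+ //     ggml_backend_free(backend);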
+
+ bool ggml_backend_is_kompute(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
+ }
+
+ static size_t ggml_backend_kompute_get_device_count() {
+ auto devices = ggml_vk_available_devices();
+ return devices.size();
+ }
+
+ static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
+ auto devices = ggml_vk_available_devices();
+ GGML_ASSERT((size_t) device < devices.size());
+ snprintf(description, description_size, "%s", devices[device].name);
+ }
+
+ static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
+ auto devices = ggml_vk_available_devices();
+ GGML_ASSERT((size_t) device < devices.size());
+ *total = devices[device].heapSize;
+ *free = devices[device].heapSize;
+ }
+
+ //////////////////////////
+
+ struct ggml_backend_kompute_device_context {
+ int device;
+ std::string name;
+ std::string description;
+ };
+
+ static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ctx->name.c_str();
+ }
+
+ static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ctx->description.c_str();
+ }
+
+ static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ ggml_backend_kompute_get_device_memory(ctx->device, free, total);
+ }
+
+ static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ggml_backend_kompute_buffer_type(ctx->device);
+ }
+
+ static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
+ return false;
+ }
+
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
+
+ return buft_ctx->device == ctx->device;
+ }
+
+ static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
+ GGML_UNUSED(dev);
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
+ }
+
+ static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+ props->name = ggml_backend_kompute_device_get_name(dev);
+ props->description = ggml_backend_kompute_device_get_description(dev);
+ props->type = ggml_backend_kompute_device_get_type(dev);
+ ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
+ props->caps = {
+ /* .async = */ false,
+ /* .host_buffer = */ false,
+ /* .buffer_from_host_ptr = */ false,
+ /* .events = */ false,
+ };
+ }
+
+ static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
+ GGML_UNUSED(params);
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+ return ggml_backend_kompute_init(ctx->device);
+ }
+
+ static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+ const int min_batch_size = 32;
+
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+ GGML_UNUSED(dev);
+ }
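+
+ // Offload heuristic: only worth shipping an op to the GPU when the batch
+ // dimension is at least 32, presumably so the host<->device transfer cost is
+ // amortized; GGML_OP_GET_ROWS is excluded as it is cheap relative to the copy.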
+
+ static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
+ /* .get_name = */ ggml_backend_kompute_device_get_name,
+ /* .get_description = */ ggml_backend_kompute_device_get_description,
+ /* .get_memory = */ ggml_backend_kompute_device_get_memory,
+ /* .get_type = */ ggml_backend_kompute_device_get_type,
+ /* .get_props = */ ggml_backend_kompute_device_get_props,
+ /* .init_backend = */ ggml_backend_kompute_device_init,
+ /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type,
+ /* .get_host_buffer_type = */ NULL,
+ /* .buffer_from_host_ptr = */ NULL,
+ /* .supports_op = */ ggml_backend_kompute_device_supports_op,
+ /* .supports_buft = */ ggml_backend_kompute_device_supports_buft,
+ /* .offload_op = */ ggml_backend_kompute_device_offload_op,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_synchronize = */ NULL,
+ };
+
+ static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
+ GGML_UNUSED(reg);
+ return "Kompute";
+ }
+
+ static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
+ GGML_UNUSED(reg);
+ return ggml_backend_kompute_get_device_count();
+ }
+
+ static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+ static std::vector<ggml_backend_dev_t> devices;
+
+ static bool initialized = false;
+
+ {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!initialized) {
+ for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
+ ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
+ char desc[256];
+ ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
+ ctx->device = i;
+ ctx->name = "Kompute" + std::to_string(i);
+ ctx->description = desc;
+ devices.push_back(new ggml_backend_device {
+ /* .iface = */ ggml_backend_kompute_device_i,
+ /* .reg = */ reg,
+ /* .context = */ ctx,
+ });
+ }
+ initialized = true;
+ }
+ }
+
+ GGML_ASSERT(device < devices.size());
+ return devices[device];
+ }
+
+ static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
+ /* .get_name = */ ggml_backend_kompute_reg_get_name,
+ /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
+ /* .get_device = */ ggml_backend_kompute_reg_get_device,
+ /* .get_proc_address = */ NULL,
+ };
+
+ ggml_backend_reg_t ggml_backend_kompute_reg() {
+ static ggml_backend_reg reg = {
+ /* .iface = */ ggml_backend_kompute_reg_i,
+ /* .context = */ nullptr,
+ };
+
+ return &reg;
+ }