waltgrace committed
Commit 6123ae4 · verified · 1 Parent(s): a8c1934

Add common/common.cpp

Files changed (1)
  1. common/common.cpp +1940 -0
common/common.cpp ADDED
@@ -0,0 +1,1940 @@
1
+ #include "ggml.h"
2
+ #include "gguf.h"
3
+
4
+ #include "common.h"
5
+ #include "log.h"
6
+ #include "llama.h"
7
+ #include "../src/llama-expert-cache-ctx.h"
8
+ #include "../src/llama-model.h"
9
+ #include "sampling.h"
10
+ #include "unicode.h"
11
+
12
+ #include <algorithm>
13
+ #include <cinttypes>
14
+ #include <climits>
15
+ #include <cmath>
16
+ #include <chrono>
17
+ #include <cstdarg>
18
+ #include <cstring>
19
+ #include <ctime>
20
+ #include <filesystem>
21
+ #include <fstream>
22
+ #include <iostream>
23
+ #include <iterator>
24
+ #include <regex>
25
+ #include <sstream>
26
+ #include <string>
27
+ #include <thread>
28
+ #include <unordered_set>
29
+ #include <vector>
30
+
31
+ #if defined(__APPLE__) && defined(__MACH__)
32
+ #include <sys/types.h>
33
+ #include <sys/sysctl.h>
34
+ #endif
35
+
36
+ #if defined(_WIN32)
37
+ #define WIN32_LEAN_AND_MEAN
38
+ #ifndef NOMINMAX
39
+ # define NOMINMAX
40
+ #endif
41
+ #include <locale>
42
+ #include <windows.h>
43
+ #include <string.h>
44
+ #include <fcntl.h>
45
+ #include <io.h>
46
+ #else
47
+ #include <sys/ioctl.h>
48
+ #include <sys/stat.h>
49
+ #include <unistd.h>
50
+ #endif
51
+
52
+ #if defined(__linux__)
53
+ #include <sys/types.h>
54
+ #include <pwd.h>
55
+ #endif
56
+
57
+ #if defined(_MSC_VER)
58
+ #pragma warning(disable: 4244 4267) // possible loss of data
59
+ #endif
60
+
61
+ common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
62
+
63
+ common_time_meas::~common_time_meas() {
64
+ if (t_start_us >= 0) {
65
+ t_acc += ggml_time_us() - t_start_us;
66
+ }
67
+ }
68
+
69
+ //
70
+ // CPU utils
71
+ //
72
+
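+ // note: cpu_get_num_physical_cores() below counts physical cores via /sys thread_siblings on Linux,
+ // sysctlbyname on macOS and GetLogicalProcessorInformationEx on Windows, falling back to a
+ // heuristic based on std::thread::hardware_concurrency() elsewhere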
73
+ int32_t cpu_get_num_physical_cores() {
74
+ #ifdef __linux__
75
+ // enumerate the set of thread siblings, num entries is num cores
76
+ std::unordered_set<std::string> siblings;
77
+ for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
78
+ std::ifstream thread_siblings("/sys/devices/system/cpu/cpu"
79
+ + std::to_string(cpu) + "/topology/thread_siblings");
80
+ if (!thread_siblings.is_open()) {
81
+ break; // no more cpus
82
+ }
83
+ std::string line;
84
+ if (std::getline(thread_siblings, line)) {
85
+ siblings.insert(line);
86
+ }
87
+ }
88
+ if (!siblings.empty()) {
89
+ return static_cast<int32_t>(siblings.size());
90
+ }
91
+ #elif defined(__APPLE__) && defined(__MACH__)
92
+ int32_t num_physical_cores;
93
+ size_t len = sizeof(num_physical_cores);
94
+ int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
95
+ if (result == 0) {
96
+ return num_physical_cores;
97
+ }
98
+ result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
99
+ if (result == 0) {
100
+ return num_physical_cores;
101
+ }
102
+ #elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
103
+ // TODO: windows + arm64 + mingw64
104
+ unsigned int n_threads_win = std::thread::hardware_concurrency();
105
+ unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;
106
+
107
+ DWORD buffer_size = 0;
108
+ if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
109
+ if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
110
+ return default_threads;
111
+ }
112
+ }
113
+
114
+ std::vector<char> buffer(buffer_size);
115
+ if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
116
+ return default_threads;
117
+ }
118
+
119
+ int32_t num_physical_cores = 0;
120
+ PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
121
+ while (buffer_size > 0) {
122
+ if (info->Relationship == RelationProcessorCore) {
123
+ num_physical_cores += info->Processor.GroupCount;
124
+ }
125
+ buffer_size -= info->Size;
126
+ info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size);
127
+ }
128
+
129
+ return num_physical_cores > 0 ? num_physical_cores : default_threads;
130
+ #endif
131
+ unsigned int n_threads = std::thread::hardware_concurrency();
132
+ return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
133
+ }
134
+
135
+ #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
136
+ #include <pthread.h>
137
+
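+ // note: the inline asm below preserves %rbx across the cpuid instruction (it may be reserved,
+ // e.g. as the PIC register); these helpers detect hybrid CPUs so efficiency cores can be
+ // skipped when counting cores that are useful for math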
138
+ static void cpuid(unsigned leaf, unsigned subleaf,
139
+ unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
140
+ __asm__("movq\t%%rbx,%%rsi\n\t"
141
+ "cpuid\n\t"
142
+ "xchgq\t%%rbx,%%rsi"
143
+ : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
144
+ : "0"(leaf), "2"(subleaf));
145
+ }
146
+
147
+ static int pin_cpu(int cpu) {
148
+ cpu_set_t mask;
149
+ CPU_ZERO(&mask);
150
+ CPU_SET(cpu, &mask);
151
+ return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
152
+ }
153
+
154
+ static bool is_hybrid_cpu(void) {
155
+ unsigned eax, ebx, ecx, edx;
156
+ cpuid(7, 0, &eax, &ebx, &ecx, &edx);
157
+ return !!(edx & (1u << 15));
158
+ }
159
+
160
+ static bool is_running_on_efficiency_core(void) {
161
+ unsigned eax, ebx, ecx, edx;
162
+ cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
163
+ int intel_atom = 0x20;
164
+ int core_type = (eax & 0xff000000u) >> 24;
165
+ return core_type == intel_atom;
166
+ }
167
+
168
+ static int cpu_count_math_cpus(int n_cpu) {
169
+ int result = 0;
170
+ for (int cpu = 0; cpu < n_cpu; ++cpu) {
171
+ if (pin_cpu(cpu)) {
172
+ return -1;
173
+ }
174
+ if (is_running_on_efficiency_core()) {
175
+ continue; // efficiency cores harm lockstep threading
176
+ }
177
+ ++cpu; // hyperthreading isn't useful for linear algebra
178
+ ++result;
179
+ }
180
+ return result;
181
+ }
182
+
183
+ #endif // __x86_64__ && __linux__
184
+
185
+ /**
186
+ * Returns number of CPUs on system that are useful for math.
187
+ */
188
+ int32_t cpu_get_num_math() {
189
+ #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
190
+ int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
191
+ if (n_cpu < 1) {
192
+ return cpu_get_num_physical_cores();
193
+ }
194
+ if (is_hybrid_cpu()) {
195
+ cpu_set_t affinity;
196
+ if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
197
+ int result = cpu_count_math_cpus(n_cpu);
198
+ pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
199
+ if (result > 0) {
200
+ return result;
201
+ }
202
+ }
203
+ }
204
+ #endif
205
+ return cpu_get_num_physical_cores();
206
+ }
207
+
208
+ // Helper for setting process priority
209
+
210
+ #if defined(_WIN32)
211
+
212
+ bool set_process_priority(enum ggml_sched_priority prio) {
213
+ if (prio == GGML_SCHED_PRIO_NORMAL) {
214
+ return true;
215
+ }
216
+
217
+ DWORD p = NORMAL_PRIORITY_CLASS;
218
+ switch (prio) {
219
+ case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break;
220
+ case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
221
+ case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
222
+ case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
223
+ case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS; break;
224
+ }
225
+
226
+ if (!SetPriorityClass(GetCurrentProcess(), p)) {
227
+ LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
228
+ return false;
229
+ }
230
+
231
+ return true;
232
+ }
233
+
234
+ #else // MacOS and POSIX
235
+ #include <sys/types.h>
236
+ #include <sys/resource.h>
237
+
238
+ bool set_process_priority(enum ggml_sched_priority prio) {
239
+ if (prio == GGML_SCHED_PRIO_NORMAL) {
240
+ return true;
241
+ }
242
+
243
+ int p = 0;
244
+ switch (prio) {
245
+ case GGML_SCHED_PRIO_LOW: p = 5; break;
246
+ case GGML_SCHED_PRIO_NORMAL: p = 0; break;
247
+ case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
248
+ case GGML_SCHED_PRIO_HIGH: p = -10; break;
249
+ case GGML_SCHED_PRIO_REALTIME: p = -20; break;
250
+ }
251
+
252
+ if (setpriority(PRIO_PROCESS, 0, p) != 0) {
253
+ LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
254
+ return false;
255
+ }
256
+ return true;
257
+ }
258
+
259
+ #endif
260
+
261
+ //
262
+ // CLI argument parsing
263
+ //
264
+
265
+
266
+ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
267
+ int32_t n_set = 0;
268
+
269
+ if (cpuparams.n_threads < 0) {
270
+ // Assuming everything about cpuparams is invalid
271
+ if (role_model != nullptr) {
272
+ cpuparams = *role_model;
273
+ } else {
274
+ cpuparams.n_threads = cpu_get_num_math();
275
+ }
276
+ }
277
+
278
+ for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
279
+ if (cpuparams.cpumask[i]) {
280
+ n_set++;
281
+ }
282
+ }
283
+
284
+ if (n_set && n_set < cpuparams.n_threads) {
285
+ // Not enough set bits, may experience performance issues.
286
+ LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
287
+ }
288
+ }
289
+
290
+ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
291
+ size_t dash_loc = range.find('-');
292
+ if (dash_loc == std::string::npos) {
293
+ LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
294
+ return false;
295
+ }
296
+
297
+ size_t start_i;
298
+ size_t end_i;
299
+
300
+ if (dash_loc == 0) {
301
+ start_i = 0;
302
+ } else {
303
+ start_i = std::stoull(range.substr(0, dash_loc));
304
+ if (start_i >= GGML_MAX_N_THREADS) {
305
+ LOG_ERR("Start index out of bounds!\n");
306
+ return false;
307
+ }
308
+ }
309
+
310
+ if (dash_loc == range.length() - 1) {
311
+ end_i = GGML_MAX_N_THREADS - 1;
312
+ } else {
313
+ end_i = std::stoull(range.substr(dash_loc + 1));
314
+ if (end_i >= GGML_MAX_N_THREADS) {
315
+ LOG_ERR("End index out of bounds!\n");
316
+ return false;
317
+ }
318
+ }
319
+
320
+ for (size_t i = start_i; i <= end_i; i++) {
321
+ boolmask[i] = true;
322
+ }
323
+
324
+ return true;
325
+ }
326
+
327
+ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) {
328
+ // Discard potential 0x prefix
329
+ size_t start_i = 0;
330
+ if (mask.length() >= 2 && mask.substr(0, 2) == "0x") {
331
+ start_i = 2;
332
+ }
333
+
334
+ size_t num_digits = mask.length() - start_i;
335
+ if (num_digits > 128) num_digits = 128;
336
+
337
+ size_t end_i = num_digits + start_i;
338
+
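+ // each hex digit selects 4 CPUs, most significant bit first; e.g. a mask of "0xA" (binary 1010)
+ // enables CPUs 3 and 1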
339
+ for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) {
340
+ char c = mask.at(i);
341
+ int8_t id = c;
342
+
343
+ if ((c >= '0' && c <= '9')) {
344
+ id -= '0';
345
+ } else if (c >= 'a' && c <= 'f') {
346
+ id -= 'a' - 10;
347
+ } else if (c >= 'A' && c <= 'F') {
348
+ id -= 'A' - 10;
349
+ } else {
350
+ LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
351
+ return false;
352
+ }
353
+
354
+ boolmask[ n ] = boolmask[ n ] || ((id & 8) != 0);
355
+ boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0);
356
+ boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0);
357
+ boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0);
358
+ }
359
+
360
+ return true;
361
+ }
362
+
363
+ void common_init() {
364
+ llama_log_set(common_log_default_callback, NULL);
365
+
366
+ #ifdef NDEBUG
367
+ const char * build_type = "";
368
+ #else
369
+ const char * build_type = " (debug)";
370
+ #endif
371
+
372
+ LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
373
+ }
374
+
375
+ std::string common_params_get_system_info(const common_params & params) {
376
+ std::ostringstream os;
377
+
378
+ os << "system_info: n_threads = " << params.cpuparams.n_threads;
379
+ if (params.cpuparams_batch.n_threads != -1) {
380
+ os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")";
381
+ }
382
+ #if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
383
+ // TODO: windows + arm64 + mingw64
384
+ DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
385
+ os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
386
+ #else
387
+ os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
388
+ #endif
389
+
390
+ return os.str();
391
+ }
392
+
393
+ //
394
+ // String utils
395
+ //
396
+
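+ // string_format: printf-style formatting into a std::string; vsnprintf runs twice (once to
+ // measure, once to write), which is why the va_list is duplicated with va_copy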
397
+ std::string string_format(const char * fmt, ...) {
398
+ va_list ap;
399
+ va_list ap2;
400
+ va_start(ap, fmt);
401
+ va_copy(ap2, ap);
402
+ int size = vsnprintf(NULL, 0, fmt, ap);
403
+ GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
404
+ std::vector<char> buf(size + 1);
405
+ int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
406
+ GGML_ASSERT(size2 == size);
407
+ va_end(ap2);
408
+ va_end(ap);
409
+ return std::string(buf.data(), size);
410
+ }
411
+
412
+ std::string string_strip(const std::string & str) {
413
+ size_t start = 0;
414
+ size_t end = str.size();
415
+ while (start < end && std::isspace(str[start])) {
416
+ start++;
417
+ }
418
+ while (end > start && std::isspace(str[end - 1])) {
419
+ end--;
420
+ }
421
+ return str.substr(start, end - start);
422
+ }
423
+
424
+ std::string string_get_sortable_timestamp() {
425
+ using clock = std::chrono::system_clock;
426
+
427
+ const clock::time_point current_time = clock::now();
428
+ const time_t as_time_t = clock::to_time_t(current_time);
429
+ char timestamp_no_ns[100];
430
+ std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
431
+
432
+ const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
433
+ current_time.time_since_epoch() % 1000000000).count();
434
+ char timestamp_ns[11];
435
+ snprintf(timestamp_ns, 11, "%09" PRId64, ns);
436
+
437
+ return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
438
+ }
439
+
440
+ void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
441
+ if (search.empty()) {
442
+ return;
443
+ }
444
+ std::string builder;
445
+ builder.reserve(s.length());
446
+ size_t pos = 0;
447
+ size_t last_pos = 0;
448
+ while ((pos = s.find(search, last_pos)) != std::string::npos) {
449
+ builder.append(s, last_pos, pos - last_pos);
450
+ builder.append(replace);
451
+ last_pos = pos + search.length();
452
+ }
453
+ builder.append(s, last_pos, std::string::npos);
454
+ s = std::move(builder);
455
+ }
456
+
457
+ std::string regex_escape(const std::string & s) {
458
+ static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
459
+ return std::regex_replace(s, special_chars, "\\$&");
460
+ }
461
+
462
+ std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
463
+ std::ostringstream result;
464
+ for (size_t i = 0; i < values.size(); ++i) {
465
+ if (i > 0) {
466
+ result << separator;
467
+ }
468
+ result << values[i];
469
+ }
470
+ return result.str();
471
+ }
472
+
473
+ std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
474
+ std::vector<std::string> parts;
475
+ size_t start = 0;
476
+ size_t end = str.find(delimiter);
477
+
478
+ while (end != std::string::npos) {
479
+ parts.push_back(str.substr(start, end - start));
480
+ start = end + delimiter.length();
481
+ end = str.find(delimiter, start);
482
+ }
483
+
484
+ parts.push_back(str.substr(start));
485
+
486
+ return parts;
487
+ }
488
+
489
+ std::string string_repeat(const std::string & str, size_t n) {
490
+ if (n == 0) {
491
+ return "";
492
+ }
493
+
494
+ std::string result;
495
+ result.reserve(str.length() * n);
496
+
497
+ for (size_t i = 0; i < n; ++i) {
498
+ result += str;
499
+ }
500
+
501
+ return result;
502
+ }
503
+
504
+ std::string string_from(bool value) {
505
+ return value ? "true" : "false";
506
+ }
507
+
508
+ std::string string_from(const std::vector<int> & values) {
509
+ std::stringstream buf;
510
+
511
+ buf << "[ ";
512
+ bool first = true;
513
+ for (auto e : values) {
514
+ if (first) {
515
+ first = false;
516
+ } else {
517
+ buf << ", ";
518
+ }
519
+ buf << std::to_string(e);
520
+ }
521
+ buf << " ]";
522
+
523
+ return buf.str();
524
+ }
525
+
526
+ std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
527
+ std::stringstream buf;
528
+
529
+ buf << "[ ";
530
+
531
+ bool first = true;
532
+ for (const auto & token : tokens) {
533
+ if (!first) {
534
+ buf << ", ";
535
+ } else {
536
+ first = false;
537
+ }
538
+
539
+ auto detokenized = common_token_to_piece(ctx, token);
540
+
541
+ buf << "'" << detokenized << "'"
542
+ << ":" << std::to_string(token);
543
+ }
544
+
545
+ buf << " ]";
546
+
547
+ return buf.str();
548
+ }
549
+
550
+ std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
551
+ std::stringstream buf;
552
+
553
+ buf << "[ ";
554
+
555
+ bool first = true;
556
+ for (int i = 0; i < batch.n_tokens; ++i) {
557
+ if (!first) {
558
+ buf << ", ";
559
+ } else {
560
+ first = false;
561
+ }
562
+
563
+ auto detokenized = common_token_to_piece(ctx, batch.token[i]);
564
+
565
+ buf << "\n" << std::to_string(i)
566
+ << ", token '" << detokenized << "'"
567
+ << ", pos " << std::to_string(batch.pos[i])
568
+ << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
569
+ << ", seq_id " << std::to_string(batch.seq_id[i][0])
570
+ << ", logits " << std::to_string(batch.logits[i]);
571
+ }
572
+
573
+ buf << " ]";
574
+
575
+ return buf.str();
576
+ }
577
+
578
+ void string_process_escapes(std::string & input) {
579
+ std::size_t input_len = input.length();
580
+ std::size_t output_idx = 0;
581
+
582
+ for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
583
+ if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
584
+ switch (input[++input_idx]) {
585
+ case 'n': input[output_idx++] = '\n'; break;
586
+ case 'r': input[output_idx++] = '\r'; break;
587
+ case 't': input[output_idx++] = '\t'; break;
588
+ case '\'': input[output_idx++] = '\''; break;
589
+ case '\"': input[output_idx++] = '\"'; break;
590
+ case '\\': input[output_idx++] = '\\'; break;
591
+ case 'x':
592
+ // Handle \x12, etc
593
+ if (input_idx + 2 < input_len) {
594
+ const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 };
595
+ char *err_p = nullptr;
596
+ const long val = std::strtol(x, &err_p, 16);
597
+ if (err_p == x + 2) {
598
+ input_idx += 2;
599
+ input[output_idx++] = char(val);
600
+ break;
601
+ }
602
+ }
603
+ // fall through
604
+ default: input[output_idx++] = '\\';
605
+ input[output_idx++] = input[input_idx]; break;
606
+ }
607
+ } else {
608
+ input[output_idx++] = input[input_idx];
609
+ }
610
+ }
611
+
612
+ input.resize(output_idx);
613
+ }
614
+
615
+ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
616
+ const char * sep = strchr(data, '=');
617
+ if (sep == nullptr || sep - data >= 128) {
618
+ LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
619
+ return false;
620
+ }
621
+ llama_model_kv_override kvo;
622
+ std::strncpy(kvo.key, data, sep - data);
623
+ kvo.key[sep - data] = 0;
624
+ sep++;
625
+ if (strncmp(sep, "int:", 4) == 0) {
626
+ sep += 4;
627
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
628
+ kvo.val_i64 = std::atol(sep);
629
+ } else if (strncmp(sep, "float:", 6) == 0) {
630
+ sep += 6;
631
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
632
+ kvo.val_f64 = std::atof(sep);
633
+ } else if (strncmp(sep, "bool:", 5) == 0) {
634
+ sep += 5;
635
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
636
+ if (std::strcmp(sep, "true") == 0) {
637
+ kvo.val_bool = true;
638
+ } else if (std::strcmp(sep, "false") == 0) {
639
+ kvo.val_bool = false;
640
+ } else {
641
+ LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
642
+ return false;
643
+ }
644
+ } else if (strncmp(sep, "str:", 4) == 0) {
645
+ sep += 4;
646
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
647
+ if (strlen(sep) > 127) {
648
+ LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
649
+ return false;
650
+ }
651
+ strncpy(kvo.val_str, sep, 127);
652
+ kvo.val_str[127] = '\0';
653
+ } else {
654
+ LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
655
+ return false;
656
+ }
657
+ overrides.emplace_back(std::move(kvo));
658
+ return true;
659
+ }
660
+
661
+ static inline bool glob_class_match(const char c, const char * pattern, const char * class_end) {
662
+ const char * class_start = pattern;
663
+ bool negated = false;
664
+
665
+ if (*class_start == '!') {
666
+ negated = true;
667
+ class_start++;
668
+ }
669
+
670
+ // If first character after negation is ']' or '-', treat it as literal
671
+ if (*class_start == ']' || *class_start == '-') {
672
+ if (class_start < class_end && *class_start == c) {
673
+ return !negated;
674
+ }
675
+ class_start++;
676
+ }
677
+
678
+ bool matched = false;
679
+
680
+ while (class_start < class_end) {
681
+ if (class_start + 2 < class_end && class_start[1] == '-' && class_start[2] != ']') {
682
+ char start_char = *class_start;
683
+ char end_char = class_start[2];
684
+ if (c >= start_char && c <= end_char) {
685
+ matched = true;
686
+ break;
687
+ }
688
+ class_start += 3;
689
+ } else {
690
+ if (*class_start == c) {
691
+ matched = true;
692
+ break;
693
+ }
694
+ class_start++;
695
+ }
696
+ }
697
+
698
+ return negated ? !matched : matched;
699
+ }
700
+
701
+ // simple glob: * matches non-/ chars, ** matches anything including /, [] matches character class
702
+ static inline bool glob_match(const char * pattern, const char * str) {
703
+ if (*pattern == '\0') {
704
+ return *str == '\0';
705
+ }
706
+ if (pattern[0] == '*' && pattern[1] == '*') {
707
+ const char * p = pattern + 2;
708
+ if (glob_match(p, str)) return true;
709
+ if (*str != '\0') return glob_match(pattern, str + 1);
710
+ return false;
711
+ }
712
+ if (*pattern == '*') {
713
+ const char * p = pattern + 1;
714
+ for (; *str != '\0' && *str != '/'; str++) {
715
+ if (glob_match(p, str)) return true;
716
+ }
717
+ return glob_match(p, str);
718
+ }
719
+ if (*pattern == '?' && *str != '\0' && *str != '/') {
720
+ return glob_match(pattern + 1, str + 1);
721
+ }
722
+ if (*pattern == '[') {
723
+ const char * class_end = pattern + 1;
724
+ // If first character after '[' is ']' or '-', treat it as literal
725
+ if (*class_end == ']' || *class_end == '-') {
726
+ class_end++;
727
+ }
728
+ while (*class_end != '\0' && *class_end != ']') {
729
+ class_end++;
730
+ }
731
+ if (*class_end == ']') {
732
+ if (*str == '\0') return false;
733
+ bool matched = glob_class_match(*str, pattern + 1, class_end);
734
+ return matched && glob_match(class_end + 1, str + 1);
735
+ } else {
736
+ if (*str == '[') {
737
+ return glob_match(pattern + 1, str + 1);
738
+ }
739
+ return false;
740
+ }
741
+ }
742
+ if (*pattern == *str) {
743
+ return glob_match(pattern + 1, str + 1);
744
+ }
745
+ return false;
746
+ }
747
+
748
+ bool glob_match(const std::string & pattern, const std::string & str) {
749
+ return glob_match(pattern.c_str(), str.c_str());
750
+ }
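+ // e.g. glob_match("*.gguf", "model.gguf") == true, but glob_match("*.gguf", "dir/model.gguf") == false
+ // since a single '*' does not cross '/'; use "**" to also match across path separators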
751
+
752
+ //
753
+ // Filesystem utils
754
+ //
755
+
756
+ // Validate if a filename is safe to use
757
+ // To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
758
+ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
759
+ if (!filename.length()) {
760
+ // Empty filename invalid
761
+ return false;
762
+ }
763
+ if (filename.length() > 255) {
764
+ // Limit at common largest possible filename on Linux filesystems
765
+ // to avoid unnecessary further validation
766
+ // (On systems with smaller limits it will be caught by the OS)
767
+ return false;
768
+ }
769
+
770
+ size_t offset = 0;
771
+ while (offset < filename.size()) {
772
+ utf8_parse_result result = common_parse_utf8_codepoint(filename, offset);
773
+
774
+ if (result.status != utf8_parse_result::SUCCESS) {
775
+ return false;
776
+ }
777
+ uint32_t c = result.codepoint;
778
+
779
+ if ((result.bytes_consumed == 2 && c < 0x80) ||
780
+ (result.bytes_consumed == 3 && c < 0x800) ||
781
+ (result.bytes_consumed == 4 && c < 0x10000)) {
782
+ return false;
783
+ }
784
+
785
+ // Check for forbidden codepoints:
786
+ // - Control characters
787
+ // - Unicode equivalents of illegal characters
788
+ // - UTF-16 surrogate pairs
789
+ // - UTF-8 replacement character
790
+ // - Byte order mark (BOM)
791
+ // - Illegal characters: / \ : * ? " < > |
792
+ if (c <= 0x1F // Control characters (C0)
793
+ || c == 0x7F // Control characters (DEL)
794
+ || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
795
+ || c == 0xFF0E // Fullwidth Full Stop (period equivalent)
796
+ || c == 0x2215 // Division Slash (forward slash equivalent)
797
+ || c == 0x2216 // Set Minus (backslash equivalent)
798
+ || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
799
+ || c > 0x10FFFF // Max Unicode limit
800
+ || c == 0xFFFD // Replacement Character (UTF-8)
801
+ || c == 0xFEFF // Byte Order Mark (BOM)
802
+ || c == ':' || c == '*' // Illegal characters
803
+ || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
804
+ return false;
805
+ }
806
+ if (!allow_subdirs && (c == '/' || c == '\\')) {
807
+ // Subdirectories not allowed, reject path separators
808
+ return false;
809
+ }
810
+ offset += result.bytes_consumed;
811
+ }
812
+
813
+ // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
814
+ // Unicode and other whitespace is not affected, only 0x20 space
815
+ if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
816
+ return false;
817
+ }
818
+
819
+ // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
820
+ if (filename.find("..") != std::string::npos) {
821
+ return false;
822
+ }
823
+
824
+ // Reject "."
825
+ if (filename == ".") {
826
+ return false;
827
+ }
828
+
829
+ return true;
830
+ }
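+ // e.g. fs_validate_filename("model.gguf", false) == true; anything containing ".." is rejected
+ // even when allow_subdirs is true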
831
+
832
+ #include <iostream>
833
+
834
+
835
+ #ifdef _WIN32
836
+ static std::wstring utf8_to_wstring(const std::string & str) {
837
+ if (str.empty()) {
838
+ return std::wstring();
839
+ }
840
+
841
+ int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
842
+
843
+ if (size <= 0) {
844
+ return std::wstring();
845
+ }
846
+
847
+ std::wstring wstr(size, 0);
848
+ MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
849
+
850
+ return wstr;
851
+ }
852
+ #endif
853
+
854
+ // returns true if successful, false otherwise
855
+ bool fs_create_directory_with_parents(const std::string & path) {
856
+ #ifdef _WIN32
857
+ std::wstring wpath = utf8_to_wstring(path);
858
+
859
+ // if the path already exists, check whether it's a directory
860
+ const DWORD attributes = GetFileAttributesW(wpath.c_str());
861
+ if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
862
+ return true;
863
+ }
864
+
865
+ size_t pos_slash = 0;
866
+
867
+ // process path from front to back, procedurally creating directories
868
+ while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
869
+ const std::wstring subpath = wpath.substr(0, pos_slash);
870
+
871
+ pos_slash += 1;
872
+
873
+ // skip the drive letter, in some systems it can return an access denied error
874
+ if (subpath.length() == 2 && subpath[1] == ':') {
875
+ continue;
876
+ }
877
+
878
+ const bool success = CreateDirectoryW(subpath.c_str(), NULL);
879
+
880
+ if (!success) {
881
+ const DWORD error = GetLastError();
882
+
883
+ // if the path already exists, ensure that it's a directory
884
+ if (error == ERROR_ALREADY_EXISTS) {
885
+ const DWORD attributes = GetFileAttributesW(subpath.c_str());
886
+ if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
887
+ return false;
888
+ }
889
+ } else {
890
+ return false;
891
+ }
892
+ }
893
+ }
894
+
895
+ return true;
896
+ #else
897
+ // if the path already exists, check whether it's a directory
898
+ struct stat info;
899
+ if (stat(path.c_str(), &info) == 0) {
900
+ return S_ISDIR(info.st_mode);
901
+ }
902
+
903
+ size_t pos_slash = 1; // skip leading slashes for directory creation
904
+
905
+ // process path from front to back, procedurally creating directories
906
+ while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
907
+ const std::string subpath = path.substr(0, pos_slash);
908
+ struct stat info;
909
+
910
+ // if the path already exists, ensure that it's a directory
911
+ if (stat(subpath.c_str(), &info) == 0) {
912
+ if (!S_ISDIR(info.st_mode)) {
913
+ return false;
914
+ }
915
+ } else {
916
+ // create parent directories
917
+ const int ret = mkdir(subpath.c_str(), 0755);
918
+ if (ret != 0) {
919
+ return false;
920
+ }
921
+ }
922
+
923
+ pos_slash += 1;
924
+ }
925
+
926
+ return true;
927
+ #endif // _WIN32
928
+ }
929
+
930
+ bool fs_is_directory(const std::string & path) {
931
+ std::filesystem::path dir(path);
932
+ return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
933
+ }
934
+
935
+ std::string fs_get_cache_directory() {
936
+ std::string cache_directory = "";
937
+ auto ensure_trailing_slash = [](std::string p) {
938
+ // Make sure to add trailing slash
939
+ if (p.back() != DIRECTORY_SEPARATOR) {
940
+ p += DIRECTORY_SEPARATOR;
941
+ }
942
+ return p;
943
+ };
944
+ if (getenv("LLAMA_CACHE")) {
945
+ cache_directory = std::getenv("LLAMA_CACHE");
946
+ } else {
947
+ #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
948
+ defined(__OpenBSD__) || defined(__NetBSD__)
949
+ if (std::getenv("XDG_CACHE_HOME")) {
950
+ cache_directory = std::getenv("XDG_CACHE_HOME");
951
+ } else if (std::getenv("HOME")) {
952
+ cache_directory = std::getenv("HOME") + std::string("/.cache/");
953
+ } else {
954
+ #if defined(__linux__)
955
+ /* no $HOME is defined, fallback to getpwuid */
956
+ struct passwd *pw = getpwuid(getuid());
957
+ if ((!pw) || (!pw->pw_dir)) {
958
+ throw std::runtime_error("Failed to find $HOME directory");
959
+ }
960
+
961
+ cache_directory = std::string(pw->pw_dir) + std::string("/.cache/");
962
+ #else /* defined(__linux__) */
963
+ throw std::runtime_error("Failed to find $HOME directory");
964
+ #endif /* defined(__linux__) */
965
+ }
966
+ #elif defined(__APPLE__)
967
+ cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
968
+ #elif defined(_WIN32)
969
+ cache_directory = std::getenv("LOCALAPPDATA");
970
+ #elif defined(__EMSCRIPTEN__)
971
+ GGML_ABORT("not implemented on this platform");
972
+ #else
973
+ # error Unknown architecture
974
+ #endif
975
+ cache_directory = ensure_trailing_slash(cache_directory);
976
+ cache_directory += "llama.cpp";
977
+ }
978
+ return ensure_trailing_slash(cache_directory);
979
+ }
980
+
981
+ std::string fs_get_cache_file(const std::string & filename) {
982
+ GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos);
983
+ std::string cache_directory = fs_get_cache_directory();
984
+ const bool success = fs_create_directory_with_parents(cache_directory);
985
+ if (!success) {
986
+ throw std::runtime_error("failed to create cache directory: " + cache_directory);
987
+ }
988
+ return cache_directory + filename;
989
+ }
990
+
991
+ std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
992
+ std::vector<common_file_info> files;
993
+ if (path.empty()) return files;
994
+
995
+ std::filesystem::path dir(path);
996
+ if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
997
+ return files;
998
+ }
999
+
1000
+ for (const auto & entry : std::filesystem::directory_iterator(dir)) {
1001
+ try {
1002
+ // Only include regular files (skip directories)
1003
+ const auto & p = entry.path();
1004
+ if (std::filesystem::is_regular_file(p)) {
1005
+ common_file_info info;
1006
+ info.path = p.string();
1007
+ info.name = p.filename().string();
1008
+ info.is_dir = false;
1009
+ try {
1010
+ info.size = static_cast<size_t>(std::filesystem::file_size(p));
1011
+ } catch (const std::filesystem::filesystem_error &) {
1012
+ info.size = 0;
1013
+ }
1014
+ files.push_back(std::move(info));
1015
+ } else if (include_directories && std::filesystem::is_directory(p)) {
1016
+ common_file_info info;
1017
+ info.path = p.string();
1018
+ info.name = p.filename().string();
1019
+ info.size = 0; // Directories have no size
1020
+ info.is_dir = true;
1021
+ files.push_back(std::move(info));
1022
+ }
1023
+ } catch (const std::filesystem::filesystem_error &) {
1024
+ // skip entries we cannot inspect
1025
+ continue;
1026
+ }
1027
+ }
1028
+
1029
+ return files;
1030
+ }
1031
+
1032
+ //
1033
+ // TTY utils
1034
+ //
1035
+
1036
+ bool tty_can_use_colors() {
1037
+ // Check NO_COLOR environment variable (https://no-color.org/)
1038
+ if (const char * no_color = std::getenv("NO_COLOR")) {
1039
+ if (no_color[0] != '\0') {
1040
+ return false;
1041
+ }
1042
+ }
1043
+
1044
+ // Check TERM environment variable
1045
+ if (const char * term = std::getenv("TERM")) {
1046
+ if (std::strcmp(term, "dumb") == 0) {
1047
+ return false;
1048
+ }
1049
+ }
1050
+
1051
+ // Check if stdout and stderr are connected to a terminal
1052
+ // We check both because log messages can go to either
1053
+ bool stdout_is_tty = isatty(fileno(stdout));
1054
+ bool stderr_is_tty = isatty(fileno(stderr));
1055
+
1056
+ return stdout_is_tty || stderr_is_tty;
1057
+ }
1058
+
1059
+ //
1060
+ // Model utils
1061
+ //
1062
+
1063
+ // TODO: move to common/sampling
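+ // reads default sampling parameters from the model's GGUF metadata; values the user set
+ // explicitly on the command line (tracked by the sparams.user_sampling_config bitmask) are
+ // left untouched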
1064
+ static void common_init_sampler_from_model(
1065
+ const llama_model * model,
1066
+ common_params_sampling & sparams) {
1067
+
1068
+ const uint64_t config = sparams.user_sampling_config;
1069
+
1070
+ auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
1071
+ if (config & user_config) {
1072
+ return;
1073
+ }
1074
+
1075
+ char buf[64] = {0};
1076
+ if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
1077
+ char * end = nullptr;
1078
+ int32_t v = strtol(buf, &end, 10);
1079
+ if (end && end != buf) {
1080
+ dst = v;
1081
+ }
1082
+ }
1083
+ };
1084
+
1085
+ auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
1086
+ if (config & user_config) {
1087
+ return;
1088
+ }
1089
+
1090
+ char buf[128] = {0};
1091
+ if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
1092
+ char * end = nullptr;
1093
+ float v = strtof(buf, &end);
1094
+ if (end && end != buf) {
1095
+ dst = v;
1096
+ }
1097
+ }
1098
+ };
1099
+
1100
+ // Sampling sequence
1101
+ if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
1102
+ char buf[512] = {0};
1103
+ if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
1104
+ const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
1105
+ if (!sampler_names.empty()) {
1106
+ sparams.samplers = common_sampler_types_from_names(sampler_names, true);
1107
+ }
1108
+ }
1109
+ }
1110
+
1111
+ get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
1112
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
1113
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
1114
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
1115
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
1116
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
1117
+ get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
1118
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
1119
+ get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
1120
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
1121
+ get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
1122
+ }
1123
+
1124
+ struct common_init_result::impl {
1125
+ impl() = default;
1126
+ ~impl() = default;
1127
+
1128
+ // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
1129
+
1130
+ llama_model_ptr model;
1131
+ llama_context_ptr context;
1132
+
1133
+ std::vector<llama_adapter_lora_ptr> lora;
1134
+
1135
+ std::vector<common_sampler_ptr> samplers;
1136
+ std::vector<llama_sampler_seq_config> samplers_seq_config;
1137
+
1138
+ // Expert cache for MoE models (optional)
1139
+ std::unique_ptr<llama_expert_cache_ctx> expert_cache;
1140
+ };
1141
+
1142
+ common_init_result::common_init_result(common_params & params) :
1143
+ pimpl(new impl{}) {
1144
+ auto mparams = common_model_params_to_llama(params);
1145
+ auto cparams = common_context_params_to_llama(params);
1146
+
1147
+ if (params.fit_params) {
1148
+ LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
1149
+ llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
1150
+ params.tensor_split,
1151
+ params.tensor_buft_overrides.data(),
1152
+ params.fit_params_target.data(),
1153
+ params.fit_params_min_ctx,
1154
+ params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
1155
+ }
1156
+
1157
+ llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
1158
+ if (model == NULL) {
1159
+ return;
1160
+ }
1161
+
1162
+ pimpl->model.reset(model);
1163
+
1164
+ const llama_vocab * vocab = llama_model_get_vocab(model);
1165
+
1166
+ // load and optionally apply lora adapters
1167
+ for (auto & la : params.lora_adapters) {
1168
+ llama_adapter_lora_ptr lora;
1169
+ lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
1170
+ if (lora == nullptr) {
1171
+ LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
1172
+ pimpl->model.reset(model);
1173
+ return;
1174
+ }
1175
+
1176
+ char buf[1024];
1177
+ la.ptr = lora.get();
1178
+ llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
1179
+ la.task_name = buf;
1180
+ llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
1181
+ la.prompt_prefix = buf;
1182
+ pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
1183
+ }
1184
+
1185
+ // updates params.sampling
1186
+ // TODO: fix naming
1187
+ common_init_sampler_from_model(model, params.sampling);
1188
+
1189
+ if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
1190
+ LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
1191
+ params.sampling.ignore_eos = false;
1192
+ }
1193
+
1194
+ // initialize once
1195
+ for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
1196
+ if (llama_vocab_is_eog(vocab, i)) {
1197
+ LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
1198
+ params.sampling.logit_bias_eog.push_back({i, -INFINITY});
1199
+ }
1200
+ }
1201
+
1202
+ if (params.sampling.ignore_eos) {
1203
+ // add EOG biases to the active set of logit biases
1204
+ params.sampling.logit_bias.insert(
1205
+ params.sampling.logit_bias.end(),
1206
+ params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
1207
+ }
1208
+
1209
+ //if (params.sampling.penalty_last_n == -1) {
1210
+ // LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
1211
+ // params.sampling.penalty_last_n = llama_n_ctx(lctx);
1212
+ //}
1213
+
1214
+ //if (params.sampling.dry_penalty_last_n == -1) {
1215
+ // LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
1216
+ // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
1217
+ //}
1218
+
1219
+ // init the backend samplers as part of the context creation
1220
+ pimpl->samplers.resize(cparams.n_seq_max);
1221
+ pimpl->samplers_seq_config.resize(cparams.n_seq_max);
1222
+
1223
+ for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
1224
+ pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
1225
+ pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
1226
+ }
1227
+
1228
+ if (params.sampling.backend_sampling) {
1229
+ cparams.samplers = pimpl->samplers_seq_config.data();
1230
+ cparams.n_samplers = pimpl->samplers_seq_config.size();
1231
+ }
1232
+
1233
+ // Initialize expert cache for MoE models if requested
1234
+ if (params.expert_cache_size > 0) {
1235
+ const auto * lmodel = reinterpret_cast<const llama_model *>(model);
1236
+ if (lmodel->hparams.n_expert > 0 && lmodel->hparams.n_expert_used > 0) {
1237
+ pimpl->expert_cache = std::make_unique<llama_expert_cache_ctx>();
1238
+ pimpl->expert_cache->init(*lmodel, params.expert_cache_size);
1239
+
1240
+ // Set eval callback to intercept ggml_mul_mat_id
1241
+ cparams.cb_eval = llama_expert_cache_ctx::eval_callback;
1242
+ cparams.cb_eval_user_data = pimpl->expert_cache.get();
1243
+
1244
+ LOG_INF("%s: expert cache enabled: %.1f MB for %d experts\n",
1245
+ __func__,
1246
+ (double)params.expert_cache_size / (1024*1024),
1247
+ (int)lmodel->hparams.n_expert);
1248
+ } else {
1249
+ LOG_WRN("%s: --expert-cache-size specified but model has no experts\n", __func__);
1250
+ }
1251
+ }
1252
+
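+ // note: cparams.cb_eval (the expert-cache callback) must be filled in before this call,
+ // since the context copies the params when it is created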
1253
+ llama_context * lctx = llama_init_from_model(model, cparams);
1254
+ if (lctx == NULL) {
1255
+ LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
1256
+ return;
1257
+ }
1258
+
1259
+ pimpl->context.reset(lctx);
1260
+ }
1261
+
1262
+ llama_model * common_init_result::model() {
1263
+ return pimpl->model.get();
1264
+ }
1265
+
1266
+ llama_context * common_init_result::context() {
1267
+ return pimpl->context.get();
1268
+ }
1269
+
1270
+ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
1271
+ return pimpl->samplers[seq_id].get();
1272
+ }
1273
+
1274
+ void common_init_result::reset_samplers() {
1275
+ for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
1276
+ llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
1277
+ }
1278
+ }
1279
+
1280
+ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
1281
+ return pimpl->lora;
1282
+ }
1283
+
1284
+ common_init_result_ptr common_init_from_params(common_params & params) {
1285
+ common_init_result_ptr res(new common_init_result(params));
1286
+
1287
+ llama_model * model = res->model();
1288
+ if (model == NULL) {
1289
+ LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
1290
+ return res;
1291
+ }
1292
+
1293
+ llama_context * lctx = res->context();
1294
+ if (lctx == NULL) {
1295
+ LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
1296
+ return res;
1297
+ }
1298
+
1299
+ const llama_vocab * vocab = llama_model_get_vocab(model);
1300
+
1301
+ if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
1302
+ LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
1303
+ params.ctx_shift = false;
1304
+ }
1305
+
1306
+ if (!params.control_vectors.empty()) {
1307
+ if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
1308
+ if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model);
1309
+
1310
+ const auto cvec = common_control_vector_load(params.control_vectors);
1311
+ if (cvec.n_embd == -1) {
1312
+ return res;
1313
+ }
1314
+
1315
+ int err = llama_set_adapter_cvec(
1316
+ lctx,
1317
+ cvec.data.data(),
1318
+ cvec.data.size(),
1319
+ cvec.n_embd,
1320
+ params.control_vector_layer_start,
1321
+ params.control_vector_layer_end);
1322
+ if (err) {
1323
+ return res;
1324
+ }
1325
+ }
1326
+
1327
+ if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
1328
+ bool ok = true;
1329
+
1330
+ if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
1331
+ LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
1332
+ ok = false;
1333
+ }
1334
+
1335
+ bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
1336
+ bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
1337
+ bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
1338
+
1339
+ if (!has_eos && !has_sep && !has_rerank_prompt) {
1340
+ LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
1341
+ ok = false;
1342
+ } else if (!has_eos) {
1343
+ LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
1344
+ }
1345
+
1346
+ if (!ok) {
1347
+ return res;
1348
+ }
1349
+ }
1350
+
1351
+ if (!params.lora_init_without_apply) {
1352
+ common_set_adapter_lora(lctx, params.lora_adapters);
1353
+ }
1354
+
1355
+ if (params.warmup) {
1356
+ LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
1357
+
1358
+ llama_set_warmup(lctx, true);
1359
+
1360
+ std::vector<llama_token> tmp;
1361
+ llama_token bos = llama_vocab_bos(vocab);
1362
+ llama_token eos = llama_vocab_eos(vocab);
1363
+
1364
+ // some models (e.g. T5) don't have a BOS token
1365
+ if (bos != LLAMA_TOKEN_NULL) {
1366
+ tmp.push_back(bos);
1367
+ }
1368
+ if (eos != LLAMA_TOKEN_NULL) {
1369
+ tmp.push_back(eos);
1370
+ }
1371
+ if (tmp.empty()) {
1372
+ tmp.push_back(0);
1373
+ }
1374
+
1375
+ if (llama_model_has_encoder(model)) {
1376
+ llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
1377
+ llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
1378
+ if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
1379
+ decoder_start_token_id = bos;
1380
+ }
1381
+ tmp.clear();
1382
+ tmp.push_back(decoder_start_token_id);
1383
+ }
1384
+ if (llama_model_has_decoder(model)) {
1385
+ llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
1386
+ }
1387
+ llama_memory_clear(llama_get_memory(lctx), true);
1388
+ llama_synchronize(lctx);
1389
+ llama_perf_context_reset(lctx);
1390
+ llama_set_warmup(lctx, false);
1391
+
1392
+ // reset samplers to reset RNG state after warmup to the seeded state
1393
+ res->reset_samplers();
1394
+ }
1395
+
1396
+ return res;
1397
+ }
1398
+
1399
+ common_init_result::~common_init_result() = default;
1400
+
1401
+ std::string get_model_endpoint() {
1402
+ const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
1403
+ // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
1404
+ const char * hf_endpoint_env = getenv("HF_ENDPOINT");
1405
+ const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
1406
+ std::string model_endpoint = "https://huggingface.co/";
1407
+ if (endpoint_env) {
1408
+ model_endpoint = endpoint_env;
1409
+ if (model_endpoint.back() != '/') {
1410
+ model_endpoint += '/';
1411
+ }
1412
+ }
1413
+ return model_endpoint;
1414
+ }
1415
+
1416
+ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
1417
+ std::vector<llama_adapter_lora *> loras;
1418
+ std::vector<float> scales;
1419
+
1420
+ for (auto & la: lora) {
1421
+ loras.push_back(la.ptr);
1422
+ scales.push_back(la.scale);
1423
+ }
1424
+
1425
+ llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
1426
+ }
1427
+
1428
+ struct llama_model_params common_model_params_to_llama(common_params & params) {
1429
+ auto mparams = llama_model_default_params();
1430
+
1431
+ if (!params.devices.empty()) {
1432
+ mparams.devices = params.devices.data();
1433
+ }
1434
+
1435
+ mparams.n_gpu_layers = params.n_gpu_layers;
1436
+ mparams.main_gpu = params.main_gpu;
1437
+ mparams.split_mode = params.split_mode;
1438
+ mparams.tensor_split = params.tensor_split;
1439
+ mparams.use_mmap = params.use_mmap;
1440
+ mparams.use_direct_io = params.use_direct_io;
1441
+ mparams.use_mlock = params.use_mlock;
1442
+ mparams.check_tensors = params.check_tensors;
1443
+ mparams.use_extra_bufts = !params.no_extra_bufts;
1444
+ mparams.no_host = params.no_host;
1445
+
1446
+ if (params.kv_overrides.empty()) {
1447
+ mparams.kv_overrides = NULL;
1448
+ } else {
1449
+ GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
1450
+ mparams.kv_overrides = params.kv_overrides.data();
1451
+ }
1452
+
1453
+ if (params.tensor_buft_overrides.empty()) {
1454
+ mparams.tensor_buft_overrides = NULL;
1455
+ } else {
1456
+ GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
1457
+ mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
1458
+ }
1459
+
1460
+ mparams.progress_callback = params.load_progress_callback;
1461
+ mparams.progress_callback_user_data = params.load_progress_callback_user_data;
1462
+
1463
+ return mparams;
1464
+ }
1465
+
1466
+ struct llama_context_params common_context_params_to_llama(const common_params & params) {
1467
+ auto cparams = llama_context_default_params();
1468
+
1469
+ cparams.n_ctx = params.n_ctx;
1470
+ cparams.n_seq_max = params.n_parallel;
1471
+ cparams.n_batch = params.n_batch;
1472
+ cparams.n_ubatch = params.n_ubatch;
1473
+ cparams.n_threads = params.cpuparams.n_threads;
1474
+ cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
1475
+ params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
1476
+ cparams.embeddings = params.embedding;
1477
+ cparams.rope_scaling_type = params.rope_scaling_type;
1478
+ cparams.rope_freq_base = params.rope_freq_base;
1479
+ cparams.rope_freq_scale = params.rope_freq_scale;
1480
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
1481
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
1482
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
1483
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
1484
+ cparams.yarn_orig_ctx = params.yarn_orig_ctx;
1485
+ cparams.pooling_type = params.pooling_type;
1486
+ cparams.attention_type = params.attention_type;
1487
+ cparams.flash_attn_type = params.flash_attn_type;
1488
+ cparams.cb_eval = params.cb_eval;
1489
+ cparams.cb_eval_user_data = params.cb_eval_user_data;
1490
+ cparams.offload_kqv = !params.no_kv_offload;
1491
+ cparams.no_perf = params.no_perf;
1492
+ cparams.op_offload = !params.no_op_offload;
1493
+ cparams.swa_full = params.swa_full;
1494
+ cparams.kv_unified = params.kv_unified;
1495
+
1496
+ cparams.type_k = params.cache_type_k;
1497
+ cparams.type_v = params.cache_type_v;
1498
+
1499
+ return cparams;
1500
+ }
1501
+
1502
+ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
1503
+ struct ggml_threadpool_params tpp;
1504
+
1505
+ ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
1506
+
1507
+ if (params.mask_valid) {
1508
+ std::memcpy(&tpp.cpumask, &params.cpumask, GGML_MAX_N_THREADS);
1509
+ }
1510
+
1511
+ tpp.prio = params.priority;
1512
+ tpp.poll = params.poll;
1513
+ tpp.strict_cpu = params.strict_cpu;
1514
+
1515
+ return tpp;
1516
+ }
1517
+
1518
+ //
1519
+ // Batch utils
1520
+ //
1521
+
1522
+ void common_batch_clear(struct llama_batch & batch) {
1523
+ batch.n_tokens = 0;
1524
+ }
1525
+
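+ // common_batch_add appends one token to the batch; the assert below relies on the batch
+ // having been allocated with enough seq_id slots and trips when the capacity is exceeded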
1526
+ void common_batch_add(
1527
+ struct llama_batch & batch,
1528
+ llama_token id,
1529
+ llama_pos pos,
1530
+ const std::vector<llama_seq_id> & seq_ids,
1531
+ bool logits) {
1532
+ GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
1533
+
1534
+ batch.token [batch.n_tokens] = id;
1535
+ batch.pos [batch.n_tokens] = pos;
1536
+ batch.n_seq_id[batch.n_tokens] = seq_ids.size();
1537
+ for (size_t i = 0; i < seq_ids.size(); ++i) {
1538
+ batch.seq_id[batch.n_tokens][i] = seq_ids[i];
1539
+ }
1540
+ batch.logits [batch.n_tokens] = logits;
1541
+
1542
+ batch.n_tokens++;
1543
+ }
1544
+
1545
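+ // Illustrative sketch (editorial): building a single-sequence prompt batch with these helpers,
+ // requesting logits only for the last token; prompt_tokens and lctx are assumed to exist:
+ //
+ //   llama_batch batch = llama_batch_init(512, 0, 1);
+ //   common_batch_clear(batch);
+ //   for (size_t i = 0; i < prompt_tokens.size(); ++i) {
+ //       const bool want_logits = (i == prompt_tokens.size() - 1);
+ //       common_batch_add(batch, prompt_tokens[i], (llama_pos) i, { 0 }, want_logits);
+ //   }
+ //   llama_decode(lctx, batch);
+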
+ //
+ // Vocab utils
+ //
+
+ std::vector<llama_token> common_tokenize(
+         const struct llama_context * ctx,
+         const std::string & text,
+         bool add_special,
+         bool parse_special) {
+     const llama_model * model = llama_get_model(ctx);
+     const llama_vocab * vocab = llama_model_get_vocab(model);
+     return common_tokenize(vocab, text, add_special, parse_special);
+ }
+
+ std::vector<llama_token> common_tokenize(
+         const struct llama_vocab * vocab,
+         const std::string & text,
+         bool add_special,
+         bool parse_special) {
+     // upper limit for the number of tokens
+     int n_tokens = text.length() + 2 * add_special;
+     std::vector<llama_token> result(n_tokens);
+     n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+     if (n_tokens == std::numeric_limits<int32_t>::min()) {
+         throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
+     }
+     if (n_tokens < 0) {
+         result.resize(-n_tokens);
+         int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+         GGML_ASSERT(check == -n_tokens);
+     } else {
+         result.resize(n_tokens);
+     }
+     return result;
+ }
+
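+ // Note (editorial): llama_tokenize reports an undersized output buffer by returning the negated
+ // required token count, which is why the helper above resizes and retries once. A typical call:
+ //
+ //   std::vector<llama_token> toks = common_tokenize(lctx, "Hello world", /*add_special=*/true, /*parse_special=*/false);
+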
+ std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+     const llama_model * model = llama_get_model(ctx);
+     const llama_vocab * vocab = llama_model_get_vocab(model);
+     return common_token_to_piece(vocab, token, special);
+ }
+
+ std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
+     std::string piece;
+     piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
+     const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+     if (n_chars < 0) {
+         piece.resize(-n_chars);
+         int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+         GGML_ASSERT(check == -n_chars);
+     }
+     else {
+         piece.resize(n_chars);
+     }
+
+     return piece;
+ }
+
+ std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+     const llama_model * model = llama_get_model(ctx);
+     const llama_vocab * vocab = llama_model_get_vocab(model);
+     return common_detokenize(vocab, tokens, special);
+ }
+
+ std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
+     std::string text;
+     text.resize(std::max(text.capacity(), tokens.size()));
+     int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+     if (n_chars < 0) {
+         text.resize(-n_chars);
+         n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+         GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+     }
+
+     text.resize(n_chars);
+
+     // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+     return text;
+ }
+
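+ // Illustrative sketch (editorial): detokenization is roughly the inverse of tokenization, so a
+ // roundtrip is a convenient sanity check; special-token and leading-whitespace handling can make
+ // the result differ slightly from the original input:
+ //
+ //   const std::string back = common_detokenize(lctx, common_tokenize(lctx, "Hello world", false, false), false);
+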
+ //
+ // Embedding utils
+ //
+
+ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
+     double sum = 0.0;
+
+     switch (embd_norm) {
+         case -1: // no normalisation
+             sum = 1.0;
+             break;
+         case 0: // max absolute
+             for (int i = 0; i < n; i++) {
+                 if (sum < std::abs(inp[i])) {
+                     sum = std::abs(inp[i]);
+                 }
+             }
+             sum /= 32760.0; // make an int16 range
+             break;
+         case 2: // euclidean
+             for (int i = 0; i < n; i++) {
+                 sum += inp[i] * inp[i];
+             }
+             sum = std::sqrt(sum);
+             break;
+         default: // p-norm (euclidean is p-norm p=2)
+             for (int i = 0; i < n; i++) {
+                 sum += std::pow(std::abs(inp[i]), embd_norm);
+             }
+             sum = std::pow(sum, 1.0 / embd_norm);
+             break;
+     }
+
+     const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
+
+     for (int i = 0; i < n; i++) {
+         out[i] = inp[i] * norm;
+     }
+ }
+
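+ // Summary of the cases above (editorial): for embd_norm == 2 the output is
+ // out[i] = inp[i] / sqrt(sum_j inp[j]^2); for embd_norm == 0 the vector is rescaled so that its
+ // largest magnitude maps to roughly the int16 range (out[i] = inp[i] * 32760 / max_j |inp[j]|);
+ // for any other p > 0 it is the general p-norm out[i] = inp[i] / (sum_j |inp[j]|^p)^(1/p);
+ // embd_norm == -1 leaves the vector unchanged.
+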
+ float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){
+     double sum  = 0.0;
+     double sum1 = 0.0;
+     double sum2 = 0.0;
+
+     for (int i = 0; i < n; i++) {
+         sum  += embd1[i] * embd2[i];
+         sum1 += embd1[i] * embd1[i];
+         sum2 += embd2[i] * embd2[i];
+     }
+
+     // Handle the case where one or both vectors are zero vectors
+     if (sum1 == 0.0 || sum2 == 0.0) {
+         if (sum1 == 0.0 && sum2 == 0.0) {
+             return 1.0f; // two zero vectors are similar
+         }
+         return 0.0f;
+     }
+
+     return sum / (sqrt(sum1) * sqrt(sum2));
+ }
+
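+ // I.e. the standard cosine similarity cos(a, b) = (a . b) / (||a|| * ||b||), accumulated in
+ // double precision: identical non-zero vectors give 1.0, orthogonal vectors give 0.0.
+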
+ //
+ // Control vector utils
+ //
+
+ static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) {
+     common_control_vector_data result = { -1, {} };
+
+     ggml_context * ctx = nullptr;
+     struct gguf_init_params meta_gguf_params = {
+         /* .no_alloc = */ false,
+         /* .ctx      = */ &ctx,
+     };
+     struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
+     if (!ctx_gguf) {
+         LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
+         return result;
+     }
+
+     int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
+     if (n_tensors == 0) {
+         LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
+     }
+
+     for (int i = 0; i < n_tensors; i++) {
+         std::string name = gguf_get_tensor_name(ctx_gguf, i);
+
+         int layer_idx = -1;
+
+         // split on '.'
+         size_t dotpos = name.find('.');
+         if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
+             try {
+                 layer_idx = std::stoi(name.substr(dotpos + 1));
+             } catch (...) {
+                 layer_idx = -1;
+             }
+         }
+         if (layer_idx < 0) {
+             LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
+             result.n_embd = -1;
+             break;
+         } else if (layer_idx == 0) {
+             LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
+             result.n_embd = -1;
+             break;
+         }
+
+         struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
+         if (tensor->type != GGML_TYPE_F32) {
+             LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
+             result.n_embd = -1;
+             break;
+         }
+         if (ggml_n_dims(tensor) != 1) {
+             LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
+             result.n_embd = -1;
+             break;
+         }
+
+         if (result.n_embd == -1) {
+             result.n_embd = ggml_nelements(tensor);
+         } else if (ggml_nelements(tensor) != result.n_embd) {
+             LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
+             result.n_embd = -1;
+             break;
+         }
+
+         // extend if necessary - do not store data for layer 0 (it's not used)
+         result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
+
+         const float * src = (const float *) tensor->data;
+         float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0]
+         for (int j = 0; j < result.n_embd; j++) {
+             dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file
+         }
+
+     }
+
+     if (result.n_embd == -1) {
+         LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
+         result.data.clear();
+     }
+
+     gguf_free(ctx_gguf);
+     ggml_free(ctx);
+
+     return result;
+ }
+
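+ // Layout note (editorial summary of the loop above): result.data holds n_embd floats per layer,
+ // flattened as [layer 1 | layer 2 | ...], i.e. the direction for layer k (k >= 1) occupies
+ // result.data[(k - 1) * n_embd .. k * n_embd), already scaled by load_info.strength; layer 0 is
+ // never stored.
+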
+ common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos) {
+     common_control_vector_data result = { -1, {} };
+
+     for (const auto & info : load_infos) {
+         auto cur = common_control_vector_load_one(info);
+
+         if (cur.n_embd == -1) {
+             result.n_embd = -1;
+             break;
+         }
+         if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
+             LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
+             result.n_embd = -1;
+             break;
+         }
+
+         if (result.n_embd == -1) {
+             result = std::move(cur);
+         } else {
+             result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary
+             for (size_t i = 0; i < cur.data.size(); i++) {
+                 result.data[i] += cur.data[i];
+             }
+         }
+     }
+
+     if (result.n_embd == -1) {
+         LOG_ERR("%s: no valid control vector files passed\n", __func__);
+         result.data.clear();
+     }
+
+     return result;
+ }
+
+ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
+     const int64_t ne_datapoint = llama_n_ctx(ctx);
+     const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
+     ggml_opt_dataset_t result = ggml_opt_dataset_init(
+         GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
+
+     llama_token * data   = (llama_token *) ggml_opt_dataset_data(result)->data;
+     llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
+
+     for (int64_t idata = 0; idata < ndata; ++idata) {
+         memcpy(data   + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
+         memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
+     }
+
+     return result;
+ }
+
+ ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+     ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+     const lr_opt & d = *(lr_opt *) userdata;
+     result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+     result.sgd.wd = result.adamw.wd = d.wd;
+     return result;
+ }
+
+ // TODO make all command line args case-insensitive
+ static inline bool eq_case_insensitive(char const* a, char const* b) {
+     return !
+ #if defined(_MSC_VER)
+         _stricmp
+ #else
+         strcasecmp
+ #endif // defined(_MSC_VER)
+         (a, b);
+ }
+
+ enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+     if (eq_case_insensitive("adamw", n)) {
+         return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+     }
+     if (eq_case_insensitive("sgd", n)) {
+         return GGML_OPT_OPTIMIZER_TYPE_SGD;
+     }
+     return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+ }
+
+ // TODO simplify to use just log and exp
+ static float const k_log_2 = std::log(2.f);
+
+ void lr_opt::init() {
+     if (lr_min > 0 && lr_min < lr0) {
+         float nhalf = std::log(lr0 / lr_min) / k_log_2;
+         float e = epochs;
+         if (decay_epochs > 0 && decay_epochs < e) {
+             e = decay_epochs;
+         } else {
+             decay_epochs = e;
+         }
+         scale_epoch = nhalf / e;
+     }
+ }
+
+ float lr_opt::get_lr(float epoch) const {
+     float r = lr_min <= 0 ? lr0 :
+               epoch >= decay_epochs ? lr_min :
+               lr0 * std::pow(0.5f, epoch * scale_epoch);
+     LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
+     return r;
+ }
+
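+ // Worked example (editorial): with lr0 = 1e-3, lr_min = 1e-4 and decay_epochs = 10, init()
+ // computes nhalf = log2(1e-3 / 1e-4) = log2(10) ~= 3.32 halvings spread over 10 epochs, so
+ // scale_epoch ~= 0.332 and get_lr(epoch) = lr0 * 0.5^(0.332 * epoch); this reaches lr_min ~= 1e-4
+ // at epoch 10 and is clamped to lr_min for all later epochs.
+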
+ bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
+     llama_batch batch = llama_batch_get_one(&last_token, 1);
+     batch.pos = &pos;
+     if (llama_decode(ctx, batch)) {
+         LOG_ERR("%s: failed to replay last token\n", __func__);
+         return false;
+     }
+     return true;
+ }
+
+ bool common_prompt_batch_decode(
+         struct llama_context * ctx,
+         const std::vector<llama_token> & tokens,
+         int & n_past,
+         int n_batch,
+         std::string_view state_path,
+         bool save_state) {
+     const int n_eval = tokens.size();
+     if (n_eval == 0) {
+         return true;
+     }
+
+     if (save_state && n_eval > 1) {
+         const int n_tokens_before_last = n_eval - 1;
+
+         GGML_ASSERT(n_eval <= n_batch);
+
+         // Decode all but the last token so we can save the memory state before decoding the last token.
+         // This is done so we can restore the session state later and replay the last token.
+         // Memory implementations in recurrent/hybrid models don't support removing tokens from their
+         // memory, so we can't simply drop the last token from the memory and replay it, which is the
+         // reason for this logic.
+         if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
+             LOG_ERR("%s : failed to eval\n", __func__);
+             return false;
+         }
+         n_past += n_tokens_before_last;
+
+         llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
+         LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
+
+         llama_token last_token = tokens.back();
+         llama_batch batch = llama_batch_get_one(&last_token, 1);
+         int32_t pos = n_past;
+         batch.pos = &pos;
+
+         if (llama_decode(ctx, batch)) {
+             LOG_ERR("%s : failed to eval last token\n", __func__);
+             return false;
+         }
+         n_past++;
+     } else {
+         if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
+             LOG_ERR("%s : failed to eval\n", __func__);
+             return false;
+         }
+         n_past += n_eval;
+     }
+
+     return true;
+ }
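+
+ // Illustrative sketch (editorial assumption): a typical caller pairs this with
+ // llama_state_load_file() and common_replay_last_token() to resume from the saved state instead
+ // of re-decoding the whole prompt, e.g.:
+ //
+ //   std::vector<llama_token> restored(tokens.size());
+ //   size_t n_restored = 0;
+ //   if (llama_state_load_file(ctx, state_path.data(), restored.data(), restored.size(), &n_restored)) {
+ //       int n_past = (int) n_restored;
+ //       common_replay_last_token(ctx, tokens.back(), n_past); // re-decode only the last token
+ //   }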