NOT-OMEGA committed on
Commit
c32e207
Β·
verified Β·
1 Parent(s): 333e2c4

Create inference.cpp

Browse files
Files changed (1) hide show
  1. inference.cpp +477 -0
inference.cpp ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/*
 * ============================================================
 * NanoMind — Optimized GPT-2 Inference Engine v1.0
 * ============================================================
 *
 * Optimizations over baseline:
 *   ✓ AVX2 + FMA matmul (8 floats/instruction)
 *   ✓ AVX2 attention dot products (inner hs=64 loop)
 *   ✓ AVX2 weighted V accumulation
 *   ✓ Pre-allocated working buffers (no stack VLAs)
 *   ✓ OpenMP parallelism across heads + matmul rows
 *   ✓ Persistent daemon — model loaded ONCE
 *   ✓ Per-session KV-cache with LRU eviction
 *
 * ── STDIN PROTOCOL ──────────────────────────────────────────
 *   REQUEST|<sess>|<tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
 *   RESET|<sess>
 *   QUIT
 *
 * ── STDOUT PROTOCOL ─────────────────────────────────────────
 *   READY
 *   TOKEN <id> <elapsed_ms>
 *   DONE <count> <total_ms>
 *   RESET_OK
 *   ERROR <message>
 *
 * ── COMPILE ─────────────────────────────────────────────────
 *   g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 \
 *       -o inference inference.cpp -lm
 * ============================================================
 */
32
+
33
+ #include <iostream>
34
+ #include <cstdio>
35
+ #include <cstdlib>
36
+ #include <cmath>
37
+ #include <cstring>
38
+ #include <ctime>
39
+ #include <algorithm>
40
+ #include <string>
41
+ #include <vector>
42
+ #include <unordered_map>
43
+ #include <unordered_set>
44
+ #include <immintrin.h> // AVX2 + FMA
45
+ #ifdef _OPENMP
46
+ #include <omp.h>
47
+ #endif
48
+ #ifdef _WIN32
49
+ #include <windows.h>
50
+ static double get_ms() {
51
+ LARGE_INTEGER f, c;
52
+ QueryPerformanceFrequency(&f); QueryPerformanceCounter(&c);
53
+ return (double)c.QuadPart / f.QuadPart * 1000.0;
54
+ }
55
+ #else
56
+ #include <sys/time.h>
57
+ static double get_ms() {
58
+ struct timeval tv; gettimeofday(&tv, NULL);
59
+ return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0;
60
+ }
61
+ #endif
62
+
63
+ // ─────────────────────────────────────────────────────────────────────────
64
+ // Config & Weights
65
+ // ─────────────────────────────────────────────────────────────────────────
66
// Model hyperparameters, read verbatim from the model.bin header
// (5 consecutive ints, in exactly this declaration order).
struct Config {
    int n_layer,     // number of transformer blocks
        n_head,      // attention heads per block
        n_embd,      // embedding / channel width C
        block_size,  // maximum context length in positions
        vocab_size;  // token vocabulary size
};
69
+
70
// Non-owning pointer views into the single heap weight blob (g_data).
// float** members are per-layer tables indexed [0, n_layer); flat float*
// members are shared across the whole model. Set up by map_weights().
struct Weights {
    float *wte, *wpe;                  // token / positional embedding tables
    float **ln1_w, **ln1_b;            // pre-attention LayerNorm gain/bias
    float **c_attn_w, **c_attn_b;      // fused QKV projection [3C x C] + bias
    float **c_proj_w, **c_proj_b;      // attention output projection [C x C]
    float **ln2_w, **ln2_b;            // pre-MLP LayerNorm gain/bias
    float **fc_w, **fc_b;              // MLP up-projection [4C x C]
    float **mlp_proj_w, **mlp_proj_b;  // MLP down-projection [C x 4C]
    float *ln_f_w, *ln_f_b;            // final LayerNorm
    float *lm_head_w;                  // LM head [vocab_size x C]
};
81
+
82
// Global model state: config, weight views, and the raw weight buffer
// they point into (allocated once in main(), freed at shutdown).
static Config cfg;
static Weights W;
static float* g_data = nullptr;
85
+
86
+ // ─────────────────────────────────────────────────────────────────────────
87
+ // Session (per-session KV-cache + position)
88
+ // ─────────────────────────────────────────────────────────────────────────
89
// Per-session inference state: one K cache and one V cache (each
// [n_layer, block_size, n_embd] floats), the next cache write position,
// and an LRU timestamp for eviction.
struct Session {
    float* k_cache = nullptr;
    float* v_cache = nullptr;
    int pos = 0;            // next position to fill == tokens processed so far
    double last_use = 0.0;  // get_ms() timestamp; smallest value is evicted first
};
static const int MAX_SESSIONS = 20;  // cap on live KV caches before LRU eviction
static std::unordered_map<std::string, Session> g_sessions;
97
+
98
+ // ─────────────────────────────────────────────────────────────────────────
99
+ // Working Buffers β€” pre-allocated, NO stack VLAs
100
+ // ─────────────────────────────────────────────────────────────────────────
101
// Scratch buffers allocated once in main() (sizes in elements):
//   g_x [C] residual stream, g_buf [C] LN / attention output,
//   g_qkv [3C] fused QKV, g_attn_buf [n_head * block_size] scores,
//   g_ff [4C] MLP hidden, g_logits [vocab_size] output logits.
// Globals make forward() single-caller only — not thread-safe.
static float *g_x, *g_buf, *g_qkv, *g_attn_buf;
static float *g_ff, *g_logits;
static float *g_tmp_out; // for o_proj + mlp_proj output — replaces stack VLA
104
+
105
+ // ─────────────────────────────────────────────────────────────────────────
106
+ // Math Kernels
107
+ // ─────────────────────────────────────────────────────────────────────────
108
+
109
// LayerNorm: normalize x to zero mean / unit variance (biased, eps=1e-5),
// then apply the learned gain w and shift b. Writes N floats to out.
static void layer_norm(float* out, const float* x, const float* w,
                       const float* b, int N) {
    // First pass: mean.
    float mu = 0.f;
    for (int i = 0; i < N; i++) mu += x[i];
    mu /= N;
    // Second pass: biased variance around the mean.
    float ss = 0.f;
    for (int i = 0; i < N; i++) {
        const float centered = x[i] - mu;
        ss += centered * centered;
    }
    ss /= N;
    const float inv_std = 1.f / sqrtf(ss + 1e-5f);
    // Third pass: normalize and affine-transform.
    for (int i = 0; i < N; i++)
        out[i] = (x[i] - mu) * inv_std * w[i] + b[i];
}
119
+
120
// Matrix-vector product: out[M] = mat[M,K] · x[K].
// Rows are distributed across OpenMP threads; each row's dot product is
// vectorized with unaligned AVX2 loads + FMA (8 floats per instruction),
// followed by a scalar tail loop when K is not a multiple of 8.
static void matmul_vec(float* __restrict__ out,
                       const float* __restrict__ mat,
                       const float* __restrict__ x,
                       int M, int K) {
#pragma omp parallel for schedule(static)
    for (int i = 0; i < M; i++) {
        // (long long) index guards against 32-bit overflow for large M*K
        // (e.g. the vocab_size x C LM-head matrix).
        const float* row = mat + (long long)i * K;
        __m256 acc = _mm256_setzero_ps();
        int j = 0;
        for (; j <= K-8; j += 8)
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(row+j),
                                  _mm256_loadu_ps(x+j), acc);
        // Horizontal reduction of the 8 partial-sum lanes.
        float tmp[8]; _mm256_storeu_ps(tmp, acc);
        float s = tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];
        for (; j < K; j++) s += row[j] * x[j];  // scalar tail
        out[i] = s;
    }
}
139
+
140
// AVX2 dot product of a[0..n) and b[0..n) — used in the attention score
// loop, where n == head_size (hs=64 → exactly 8 vector iterations).
static inline float dot_avx2(const float* __restrict__ a,
                             const float* __restrict__ b, int n) {
    __m256 acc = _mm256_setzero_ps();
    int i = 0;
    // 8-wide FMA: acc += a[i..i+7] * b[i..i+7]
    for (; i <= n-8; i += 8)
        acc = _mm256_fmadd_ps(_mm256_loadu_ps(a+i),
                              _mm256_loadu_ps(b+i), acc);
    // Horizontal sum of the 8 lanes.
    float tmp[8]; _mm256_storeu_ps(tmp, acc);
    float s = tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7];
    for (; i < n; i++) s += a[i]*b[i];  // scalar tail for n % 8 != 0
    return s;
}
153
+
154
// AVX2 weighted accumulate: out[0..n) += w * v[0..n).
// Used to blend cached V rows by their softmaxed attention weights.
static inline void weighted_acc_avx2(float* __restrict__ out,
                                     const float* __restrict__ v,
                                     float w, int n) {
    __m256 wv = _mm256_set1_ps(w);  // broadcast the scalar weight to 8 lanes
    int i = 0;
    for (; i <= n-8; i += 8)
        _mm256_storeu_ps(out+i,
            _mm256_fmadd_ps(wv, _mm256_loadu_ps(v+i),
                            _mm256_loadu_ps(out+i)));
    for (; i < n; i++) out[i] += w * v[i];  // scalar tail
}
166
+
167
// In-place elementwise bias addition: x[i] += b[i] for i in [0, N).
static inline void add_bias(float* x, const float* b, int N) {
#pragma omp parallel for
    for (int idx = 0; idx < N; idx++) {
        x[idx] = x[idx] + b[idx];
    }
}
171
+
172
// Residual connection: accumulate the branch output y into the stream x.
static inline void residual_add(float* x, const float* y, int N) {
#pragma omp parallel for
    for (int idx = 0; idx < N; idx++) {
        x[idx] = x[idx] + y[idx];
    }
}
176
+
177
// GELU activation (tanh approximation), applied in place over x[0..N).
static void gelu_inplace(float* x, int N) {
    const float k = 0.7978845608f;  // sqrt(2/pi)
#pragma omp parallel for
    for (int i = 0; i < N; i++) {
        const float t = x[i];
        const float inner = k * (t + 0.044715f * t * t * t);
        x[i] = 0.5f * t * (1.f + tanhf(inner));
    }
}
185
+
186
// Numerically stable in-place softmax: shift by the max before
// exponentiating, then normalize so the N entries sum to 1.
static void softmax_inplace(float* x, int N) {
    float peak = x[0];
    for (int i = 1; i < N; i++) peak = (x[i] > peak) ? x[i] : peak;
    float total = 0.f;
    for (int i = 0; i < N; i++) {
        x[i] = expf(x[i] - peak);
        total += x[i];
    }
    for (int i = 0; i < N; i++) x[i] /= total;
}
193
+
194
+ // ─────────────────────────────────────────────────────────────────────────
195
+ // Forward Pass (single token at position pos)
196
+ // ─────────────────────────────────────────────────────────────────────────
197
// One transformer forward pass for a single token at position `pos`,
// reading and appending to this session's K/V caches. Leaves the
// next-token logits in g_logits. NOT thread-safe: all scratch buffers
// are globals, so at most one forward() may run at a time (parallelism
// is internal, via OpenMP).
static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
    const int C = cfg.n_embd;   // channel width
    const int H = cfg.n_head;   // number of attention heads
    const int hs = C / H;       // per-head size

    // Embedding: x = wte[token_id] + wpe[pos]
    float* te = W.wte + (long long)token_id * C;
    float* pe = W.wpe + (long long)pos * C;
#pragma omp parallel for
    for (int i = 0; i < C; i++) g_x[i] = te[i] + pe[i];

    for (int l = 0; l < cfg.n_layer; l++) {

        // ── Self-Attention ────────────────────────────────────────────────
        layer_norm(g_buf, g_x, W.ln1_w[l], W.ln1_b[l], C);
        matmul_vec(g_qkv, W.c_attn_w[l], g_buf, 3*C, C);   // fused QKV
        add_bias(g_qkv, W.c_attn_b[l], 3*C);

        // Q, K, V live back-to-back in g_qkv.
        float* q = g_qkv;
        float* k = g_qkv + C;
        float* v = g_qkv + 2*C;
        // This layer's cache slabs: [block_size, C] each.
        float* kc = k_cache + (long long)l * cfg.block_size * C;
        float* vc = v_cache + (long long)l * cfg.block_size * C;
        // Append the new K/V row at position `pos`.
        memcpy(kc + (long long)pos*C, k, C*sizeof(float));
        memcpy(vc + (long long)pos*C, v, C*sizeof(float));

        // Multi-head attention: each head attends over positions [0, pos]
        // independently; results land in g_buf. Heads run in parallel —
        // each writes disjoint slices of g_buf/g_attn_buf.
#pragma omp parallel for schedule(static)
        for (int h = 0; h < H; h++) {
            float* qh = q + h*hs;
            float scale = 1.f / sqrtf((float)hs);  // 1/sqrt(d_k)
            float* attn = g_attn_buf + h*cfg.block_size;

            // Attention scores q·k / sqrt(hs) — AVX2 dot
            for (int t = 0; t <= pos; t++) {
                float* kh = kc + (long long)t*C + h*hs;
                attn[t] = dot_avx2(qh, kh, hs) * scale;
            }
            softmax_inplace(attn, pos+1);

            // Weighted V sum — AVX2 accumulate into this head's output slice.
            float* oh = g_buf + h*hs;
            memset(oh, 0, hs*sizeof(float));
            for (int t = 0; t <= pos; t++) {
                float* vh = vc + (long long)t*C + h*hs;
                weighted_acc_avx2(oh, vh, attn[t], hs);
            }
        }

        // O projection into pre-allocated buffer (no stack VLA), then residual.
        matmul_vec(g_tmp_out, W.c_proj_w[l], g_buf, C, C);
        add_bias(g_tmp_out, W.c_proj_b[l], C);
        residual_add(g_x, g_tmp_out, C);

        // ── MLP: LN → up-project 4C → GELU → down-project → residual ─────
        layer_norm(g_buf, g_x, W.ln2_w[l], W.ln2_b[l], C);
        matmul_vec(g_ff, W.fc_w[l], g_buf, 4*C, C);
        add_bias(g_ff, W.fc_b[l], 4*C);
        gelu_inplace(g_ff, 4*C);
        matmul_vec(g_tmp_out, W.mlp_proj_w[l], g_ff, C, 4*C);
        add_bias(g_tmp_out, W.mlp_proj_b[l], C);
        residual_add(g_x, g_tmp_out, C);
    }

    // Final LayerNorm + LM head → logits over the full vocabulary.
    layer_norm(g_buf, g_x, W.ln_f_w, W.ln_f_b, C);
    matmul_vec(g_logits, W.lm_head_w, g_buf, cfg.vocab_size, C);
}
264
+
265
+ // ─────────────────────────────────────────────────────────────────────────
266
+ // Weight Mapping
267
+ // ─────────────────────────────────────────────────────────────────────────
268
// Carve the flat weight blob into the named views in W. The walk order
// below IS the contract with the model.bin exporter: embeddings first,
// then each layer's parameters in sequence, then the final LayerNorm and
// LM head. No data is copied — only pointers are assigned.
static void map_weights(float* data) {
    float* p = data;
    const int C = cfg.n_embd, L = cfg.n_layer;
    W.wte = p; p += (long long)cfg.vocab_size * C;   // token embeddings [V, C]
    W.wpe = p; p += (long long)cfg.block_size * C;   // positional embeddings [T, C]
    // Allocate the per-layer pointer tables (live for the process lifetime).
#define ARR(f) W.f = (float**)malloc(L*sizeof(float*))
    ARR(ln1_w); ARR(ln1_b); ARR(c_attn_w); ARR(c_attn_b);
    ARR(c_proj_w); ARR(c_proj_b); ARR(ln2_w); ARR(ln2_b);
    ARR(fc_w); ARR(fc_b); ARR(mlp_proj_w); ARR(mlp_proj_b);
#undef ARR
    for (int l = 0; l < L; l++) {
        W.ln1_w[l] = p; p += C;
        W.ln1_b[l] = p; p += C;
        W.c_attn_w[l] = p; p += 3LL*C*C;     // fused QKV weight [3C, C]
        W.c_attn_b[l] = p; p += 3LL*C;
        W.c_proj_w[l] = p; p += 1LL*C*C;     // attention output projection [C, C]
        W.c_proj_b[l] = p; p += C;
        W.ln2_w[l] = p; p += C;
        W.ln2_b[l] = p; p += C;
        W.fc_w[l] = p; p += 4LL*C*C;         // MLP up-projection [4C, C]
        W.fc_b[l] = p; p += 4LL*C;
        W.mlp_proj_w[l]= p; p += 1LL*C*4*C;  // MLP down-projection [C, 4C]
        W.mlp_proj_b[l]= p; p += C;
    }
    W.ln_f_w = p; p += C;
    W.ln_f_b = p; p += C;
    // LM head [V, C]. NOTE(review): assumes the exporter wrote a separate
    // (possibly wte-tied) LM head matrix here — confirm against exporter.
    W.lm_head_w = p;
}
296
+
297
+ // ─────────────────────────────────────────────────────────────────────────
298
+ // Session Management (LRU eviction)
299
+ // ─────────────────────────────────────────────────────────────────────────
300
+ static long long kv_bytes() {
301
+ return (long long)cfg.n_layer * cfg.block_size * cfg.n_embd * sizeof(float);
302
+ }
303
+
304
+ static void free_session(Session& s) {
305
+ free(s.k_cache); free(s.v_cache);
306
+ s.k_cache = nullptr; s.v_cache = nullptr; s.pos = 0;
307
+ }
308
+
309
+ static void evict_oldest() {
310
+ if (g_sessions.empty()) return;
311
+ std::string oid; double ot = 1e300;
312
+ for (auto& kv : g_sessions)
313
+ if (kv.second.last_use < ot) { ot = kv.second.last_use; oid = kv.first; }
314
+ free_session(g_sessions[oid]);
315
+ g_sessions.erase(oid);
316
+ }
317
+
318
+ static Session& get_or_create(const std::string& id) {
319
+ auto it = g_sessions.find(id);
320
+ if (it != g_sessions.end()) {
321
+ it->second.last_use = get_ms();
322
+ return it->second;
323
+ }
324
+ if ((int)g_sessions.size() >= MAX_SESSIONS) evict_oldest();
325
+ Session s;
326
+ long long nb = kv_bytes();
327
+ s.k_cache = (float*)calloc(nb, 1);
328
+ s.v_cache = (float*)calloc(nb, 1);
329
+ s.pos = 0;
330
+ s.last_use = get_ms();
331
+ g_sessions[id] = s;
332
+ return g_sessions[id];
333
+ }
334
+
335
+ // ─────────────────────────────────────────────────────────────────────────
336
+ // Top-K Sampler
337
+ // ─────────────────────────────────────────────────────────────────────────
338
+ static int sample_topk(float temperature, int top_k) {
339
+ for (int v = 0; v < cfg.vocab_size; v++) g_logits[v] /= temperature;
340
+ int K = std::min(top_k, cfg.vocab_size);
341
+ std::vector<std::pair<float,int>> pairs(cfg.vocab_size);
342
+ for (int v = 0; v < cfg.vocab_size; v++) pairs[v] = {g_logits[v], v};
343
+ std::partial_sort(pairs.begin(), pairs.begin()+K, pairs.end(),
344
+ [](const auto& a, const auto& b){ return a.first > b.first; });
345
+ float sum = 0.f;
346
+ for (int j = 0; j < K; j++) { pairs[j].first = expf(pairs[j].first); sum += pairs[j].first; }
347
+ for (int j = 0; j < K; j++) pairs[j].first /= sum;
348
+ float r = (float)rand() / ((float)RAND_MAX+1.f), cum = 0.f;
349
+ int best = pairs[0].second;
350
+ for (int j = 0; j < K; j++) { cum += pairs[j].first; if (r < cum) { best = pairs[j].second; break; } }
351
+ return best;
352
+ }
353
+
354
+ // ─────────────────────────────────────────────────────────────────────────
355
+ // Helpers
356
+ // ─────────────────────────────────────────────────────────────────────────
357
// Split s on delimiter d. Always returns at least one element; empty
// fields (including a trailing one) are preserved, matching the protocol's
// "empty token list" case.
static std::vector<std::string> split(const std::string& s, char d) {
    std::vector<std::string> parts;
    std::string::size_type start = 0, hit;
    while ((hit = s.find(d, start)) != std::string::npos) {
        parts.emplace_back(s, start, hit - start);
        start = hit + 1;
    }
    parts.emplace_back(s, start, std::string::npos);
    return parts;
}
// Parse a comma-separated list of integers, skipping empty fields.
// Non-numeric fields parse as 0 (atoi semantics).
static std::vector<int> parse_ints(const std::string& s) {
    std::vector<int> ids;
    for (const auto& tok : split(s, ','))
        if (!tok.empty()) ids.push_back(atoi(tok.c_str()));
    return ids;
}
367
+
368
+ // ─────────────────────────────────────────────────────────────────────────
369
+ // Command Handlers
370
+ // ─────────────────────────────────────────────────────────────────────────
371
// Handle one REQUEST line:
//   REQUEST|<sess>|<tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
// Prefills the session's KV cache with the prompt tokens, then generates
// up to max_new tokens, streaming "TOKEN <id> <ms>" lines and a final
// "DONE <count> <ms>". Errors are reported as "ERROR <reason>".
static void handle_request(const std::string& line) {
    auto parts = split(line, '|');
    if (parts.size() < 7) {
        printf("ERROR bad_request_format\n"); fflush(stdout); return;
    }
    std::string sess_id = parts[1];
    auto new_tokens = parse_ints(parts[2]);
    int max_new = atoi(parts[3].c_str());
    float temp = (float)atof(parts[4].c_str());
    int top_k = atoi(parts[5].c_str());
    auto stop_list = parse_ints(parts[6]);

    // Sanitize sampling parameters against degenerate client input.
    temp = std::max(temp, 0.01f);                  // avoid divide-by-~0
    top_k = std::clamp(top_k, 1, cfg.vocab_size);
    max_new = std::max(max_new, 1);

    std::unordered_set<int> stop_ids(stop_list.begin(), stop_list.end());
    stop_ids.insert(50256); // <|endoftext|> always terminates generation

    Session& sess = get_or_create(sess_id);

    // Prefill: run each prompt token through the model, extending the
    // KV cache. After this loop g_logits holds the last token's logits.
    // NOTE(review): if <tokens_csv> is empty on a brand-new session,
    // sampling below reads whatever g_logits last held — callers appear
    // expected to always send at least one token; confirm.
    for (int tok : new_tokens) {
        if (sess.pos >= cfg.block_size) {
            printf("ERROR context_window_full\n"); fflush(stdout); return;
        }
        forward(tok, sess.pos, sess.k_cache, sess.v_cache);
        sess.pos++;
    }

    // Autoregressive generation. A stop token IS emitted (TOKEN line and
    // counted in gen) before the loop breaks.
    double t0 = get_ms();
    int gen = 0;
    for (int i = 0; i < max_new; i++) {
        if (sess.pos >= cfg.block_size) break;  // context exhausted
        int next = sample_topk(temp, top_k);
        printf("TOKEN %d %.2f\n", next, get_ms()-t0);
        fflush(stdout);  // stream each token immediately
        gen++;
        if (stop_ids.count(next)) break;
        forward(next, sess.pos, sess.k_cache, sess.v_cache);
        sess.pos++;
    }

    printf("DONE %d %.2f\n", gen, get_ms()-t0);
    fflush(stdout);
}
418
+
419
+ static void handle_reset(const std::string& line) {
420
+ auto parts = split(line, '|');
421
+ if (parts.size() >= 2) {
422
+ auto it = g_sessions.find(parts[1]);
423
+ if (it != g_sessions.end()) {
424
+ free_session(it->second); g_sessions.erase(it);
425
+ }
426
+ }
427
+ printf("RESET_OK\n"); fflush(stdout);
428
+ }
429
+
430
+ // ─────────────────────────────────────────────────────────────────────────
431
+ // main
432
+ // ─────────────────────────────────────────────────────────────────────────
433
+ int main() {
434
+ FILE* f = fopen("model.bin", "rb");
435
+ if (!f) { printf("ERROR model.bin_not_found\n"); fflush(stdout); return 1; }
436
+
437
+ fread(&cfg, sizeof(int), 5, f);
438
+ fseek(f, 0, SEEK_END); long fsize = ftell(f);
439
+ fseek(f, 5*(long)sizeof(int), SEEK_SET);
440
+ long wbytes = fsize - 5*(long)sizeof(int);
441
+
442
+ g_data = (float*)malloc(wbytes);
443
+ if (!g_data) { printf("ERROR oom\n"); fflush(stdout); return 1; }
444
+ fread(g_data, 1, wbytes, f);
445
+ fclose(f);
446
+
447
+ map_weights(g_data);
448
+
449
+ const int C = cfg.n_embd;
450
+ // Pre-allocate all working buffers β€” zero stack VLAs
451
+ g_x = (float*)malloc(C * sizeof(float));
452
+ g_buf = (float*)malloc(C * sizeof(float));
453
+ g_qkv = (float*)malloc(3*C * sizeof(float));
454
+ g_attn_buf = (float*)malloc((long long)cfg.n_head * cfg.block_size * sizeof(float));
455
+ g_ff = (float*)malloc(4*C * sizeof(float));
456
+ g_logits = (float*)malloc((long long)cfg.vocab_size * sizeof(float));
457
+ g_tmp_out = (float*)malloc(C * sizeof(float)); // replaces stack VLAs
458
+
459
+ srand((unsigned)time(NULL));
460
+ printf("READY\n"); fflush(stdout);
461
+
462
+ std::string line;
463
+ while (std::getline(std::cin, line)) {
464
+ if (!line.empty() && line.back()=='\r') line.pop_back();
465
+ if (line.empty()) continue;
466
+ if (line == "QUIT") break;
467
+ else if (line.rfind("RESET|", 0)==0) handle_reset(line);
468
+ else if (line.rfind("REQUEST|", 0)==0) handle_request(line);
469
+ else { printf("ERROR unknown_cmd\n"); fflush(stdout); }
470
+ }
471
+
472
+ for (auto& kv : g_sessions) free_session(kv.second);
473
+ free(g_data);
474
+ free(g_x); free(g_buf); free(g_qkv); free(g_attn_buf);
475
+ free(g_ff); free(g_logits); free(g_tmp_out);
476
+ return 0;
477
+ }