NOT-OMEGA commited on
Commit
c11b4f4
Β·
verified Β·
1 Parent(s): ba21f66

Update inference.cpp

Browse files
Files changed (1) hide show
  1. inference.cpp +41 -73
inference.cpp CHANGED
@@ -3,44 +3,27 @@
3
  * KVInfer β€” PERSISTENT DAEMON INFERENCE ENGINE v2.0
4
  * ============================================================
5
  *
6
- * FIX #1 Persistent process: model loads ONCE at startup.
7
- * Handles unlimited requests over stdin/stdout pipe.
8
- * No more subprocess-per-request overhead.
9
- *
10
- * FIX #3 Session KV-cache reuse: each session_id keeps its
11
- * own KV cache + position. New chat turns only run
12
- * forward() on NEW tokens β€” full history stays cached.
13
- * Massive TTFT reduction on multi-turn conversations.
14
- *
15
- * FIX #4 Stop-token list: caller passes extra stop IDs (e.g.
16
- * the encoded <|user|> token) so the model cannot bleed
17
- * into the next speaker's turn.
18
- *
19
  * ── STDIN PROTOCOL ──────────────────────────────────────────
20
  * REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
21
  * RESET|<sess>
22
  * QUIT
23
  *
24
  * ── STDOUT PROTOCOL ─────────────────────────────────────────
25
- * READY (once, after model loads)
26
- * TOKEN <id> <elapsed_ms> (one per generated token)
27
- * DONE <count> <total_ms> (end of one request)
28
- * RESET_OK (ack for RESET)
29
  * ERROR <message>
30
  *
31
- * ── COMPILE (MSVC, Developer Prompt) ────────────────────────
32
- * cl /O2 /openmp /arch:AVX2 /fp:fast /std:c++17 /EHsc /Fe:inference.exe inference.cpp
33
- *
34
- * ── COMPILE (GCC / MinGW) ───────────────────────────────────
35
- * g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 -o inference.exe inference.cpp
36
  * ============================================================
37
  */
38
-
39
  #include <stdio.h>
40
  #include <stdlib.h>
41
  #include <math.h>
42
  #include <string.h>
43
- #include <iostream>
44
  #include <time.h>
45
  #include <algorithm>
46
  #include <string>
@@ -48,11 +31,9 @@
48
  #include <unordered_set>
49
  #include <vector>
50
  #include <immintrin.h> // AVX2 + FMA
51
-
52
  #ifdef _OPENMP
53
  #include <omp.h>
54
  #endif
55
-
56
  #ifdef _WIN32
57
  #include <windows.h>
58
  static double get_time_ms() {
@@ -72,9 +53,7 @@
72
  // ─────────────────────────────────────────────────────────────────────────
73
  // Model Structures
74
  // ─────────────────────────────────────────────────────────────────────────
75
-
76
  typedef struct { int n_layer, n_head, n_embd, block_size, vocab_size; } Config;
77
-
78
  typedef struct {
79
  float *wte, *wpe;
80
  float **ln1_w, **ln1_b;
@@ -90,7 +69,7 @@ typedef struct {
90
  struct SessionState {
91
  float* k_cache = nullptr;
92
  float* v_cache = nullptr;
93
- int pos = 0; // tokens already in KV cache
94
  double last_used = 0.0;
95
  };
96
 
@@ -98,17 +77,20 @@ static Config cfg;
98
  static Weights W;
99
  static float* g_model_data = nullptr;
100
 
101
- // LRU session store
102
- static const int MAX_SESSIONS = 4;
 
 
 
 
103
  static std::unordered_map<std::string, SessionState> g_sessions;
104
 
105
  // Shared per-request working buffers
106
  static float *g_x, *g_buf, *g_qkv, *g_attn, *g_ff, *g_logits;
107
 
108
  // ─────────────────────────────────────────────────────────────────────────
109
- // Math Kernels (AVX2 + FMA + OpenMP)
110
  // ─────────────────────────────────────────────────────────────────────────
111
-
112
  static void layer_norm(float* out, const float* x, const float* w,
113
  const float* b, int N) {
114
  float mean = 0.f, var = 0.f;
@@ -165,20 +147,16 @@ static void softmax_inplace(float* x, int N) {
165
  }
166
 
167
  // ─────────────────────────────────────────────────────────────────────────
168
- // Transformer Forward (single token at position `pos`)
169
- // Writes next-token log-probs into g_logits.
170
  // ─────────────────────────────────────────────────────────────────────────
171
-
172
  static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
173
  const int C = cfg.n_embd, H = cfg.n_head, hs = C/H;
174
-
175
  float* te = W.wte + (long long)token_id*C;
176
  float* pe = W.wpe + (long long)pos*C;
177
  #pragma omp parallel for
178
  for (int i = 0; i < C; i++) g_x[i] = te[i]+pe[i];
179
 
180
  for (int l = 0; l < cfg.n_layer; l++) {
181
-
182
  // Self-attention
183
  layer_norm(g_buf, g_x, W.ln1_w[l], W.ln1_b[l], C);
184
  matmul_vec(g_qkv, W.c_attn_w[l], g_buf, 3*C, C);
@@ -233,34 +211,30 @@ static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
233
  // ─────────────────────────────────────────────────────────────────────────
234
  // Weight Mapping
235
  // ─────────────────────────────────────────────────────────────────────────
236
-
237
  static void map_weights(float* data) {
238
  float* p = data;
239
  const int C = cfg.n_embd, L = cfg.n_layer;
240
  W.wte=p; p+=(long long)cfg.vocab_size*C;
241
  W.wpe=p; p+=(long long)cfg.block_size*C;
242
-
243
  #define ARR(f) W.f=(float**)malloc(L*sizeof(float*))
244
  ARR(ln1_w); ARR(ln1_b); ARR(c_attn_w); ARR(c_attn_b);
245
  ARR(c_proj_w); ARR(c_proj_b); ARR(ln2_w); ARR(ln2_b);
246
  ARR(fc_w); ARR(fc_b); ARR(mlp_proj_w); ARR(mlp_proj_b);
247
  #undef ARR
248
-
249
  for (int l = 0; l < L; l++) {
250
- W.ln1_w[l]=p; p+=C; W.ln1_b[l]=p; p+=C;
251
  W.c_attn_w[l]=p; p+=3LL*C*C; W.c_attn_b[l]=p; p+=3LL*C;
252
  W.c_proj_w[l]=p; p+=1LL*C*C; W.c_proj_b[l]=p; p+=C;
253
- W.ln2_w[l]=p; p+=C; W.ln2_b[l]=p; p+=C;
254
- W.fc_w[l]=p; p+=4LL*C*C; W.fc_b[l]=p; p+=4LL*C;
255
  W.mlp_proj_w[l]=p; p+=1LL*C*4*C; W.mlp_proj_b[l]=p; p+=C;
256
  }
257
  W.ln_f_w=p; p+=C; W.ln_f_b=p; p+=C; W.lm_head_w=p;
258
  }
259
 
260
  // ─────────────────────────────────────────────────────────────────────────
261
- // Session Management (LRU, max MAX_SESSIONS)
262
  // ─────────────────────────────────────────────────────────────────────────
263
-
264
  static long long kv_alloc_bytes() {
265
  return (long long)cfg.n_layer * cfg.block_size * cfg.n_embd * sizeof(float);
266
  }
@@ -297,16 +271,16 @@ static SessionState& get_or_create(const std::string& id) {
297
  }
298
 
299
  // ─────────────────────────────────────────────────────────────────────────
300
- // Sampler
301
  // ─────────────────────────────────────────────────────────────────────────
302
-
303
  static int sample_topk(float temperature, int top_k) {
304
  for (int v = 0; v < cfg.vocab_size; v++) g_logits[v] /= temperature;
305
  std::vector<std::pair<float,int>> pairs(cfg.vocab_size);
306
  for (int v = 0; v < cfg.vocab_size; v++) pairs[v]={g_logits[v],v};
307
  std::partial_sort(pairs.begin(), pairs.begin()+top_k, pairs.end(),
308
- [](const std::pair<float,int>& a,const std::pair<float,int>& b){
309
- return a.first>b.first;});
 
310
  float sum=0.f;
311
  for (int j=0; j<top_k; j++) { pairs[j].first=expf(pairs[j].first); sum+=pairs[j].first; }
312
  for (int j=0; j<top_k; j++) pairs[j].first /= sum;
@@ -319,7 +293,6 @@ static int sample_topk(float temperature, int top_k) {
319
  // ─────────────────────────────────────────��───────────────────────────────
320
  // Helpers
321
  // ─────────────────────────────────────────────────────────────────────────
322
-
323
  static std::vector<std::string> split(const std::string& s, char d) {
324
  std::vector<std::string> out; std::string cur;
325
  for (char c:s){ if(c==d){out.push_back(cur);cur.clear();}else cur+=c; }
@@ -336,19 +309,17 @@ static std::vector<int> parse_ints(const std::string& s) {
336
  // ─────────────────────────────────────────────────────────────────────────
337
  // Command Handlers
338
  // ─────────────────────────────────────────────────────────────────────────
339
-
340
- // REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
341
  static void handle_request(const std::string& line) {
342
  auto parts = split(line, '|');
343
  if (parts.size() < 7) {
344
  printf("ERROR bad_request_format\n"); fflush(stdout); return;
345
  }
346
- std::string sess_id = parts[1];
347
- auto new_tokens = parse_ints(parts[2]);
348
- int max_new = atoi(parts[3].c_str());
349
- float temp = (float)atof(parts[4].c_str());
350
- int top_k = atoi(parts[5].c_str());
351
- auto stop_list = parse_ints(parts[6]);
352
 
353
  if (temp < 0.01f) temp = 0.01f;
354
  if (top_k < 1) top_k = 1;
@@ -356,11 +327,11 @@ static void handle_request(const std::string& line) {
356
  if (max_new < 1) max_new = 1;
357
 
358
  std::unordered_set<int> stop_ids(stop_list.begin(), stop_list.end());
359
- stop_ids.insert(50256); // <|endoftext|> always a stop
360
 
361
  SessionState& sess = get_or_create(sess_id);
362
 
363
- // ── Prefill new tokens (updates session KV cache) ─────────────────
364
  for (int tok : new_tokens) {
365
  if (sess.pos >= cfg.block_size) {
366
  printf("ERROR context_window_full\n"); fflush(stdout); return;
@@ -369,10 +340,9 @@ static void handle_request(const std::string& line) {
369
  sess.pos++;
370
  }
371
 
372
- // ── Autoregressive generation ─────────────────────────────────────
373
  double t0 = get_time_ms();
374
  int gen = 0;
375
-
376
  for (int i = 0; i < max_new; i++) {
377
  if (sess.pos >= cfg.block_size) break;
378
  int best = sample_topk(temp, top_k);
@@ -388,7 +358,6 @@ static void handle_request(const std::string& line) {
388
  fflush(stdout);
389
  }
390
 
391
- // RESET|<sess>
392
  static void handle_reset(const std::string& line) {
393
  auto parts = split(line, '|');
394
  if (parts.size() < 2) { printf("RESET_OK\n"); fflush(stdout); return; }
@@ -401,9 +370,8 @@ static void handle_reset(const std::string& line) {
401
  }
402
 
403
  // ─────────────────────────────────────────────────────────────────────────
404
- // MAIN β€” load model once, then serve from stdin forever
405
  // ─────────────────────────────────────────────────────────────────────────
406
-
407
  int main() {
408
  FILE* f = fopen("model.bin", "rb");
409
  if (!f) { printf("ERROR model.bin_not_found\n"); fflush(stdout); return 1; }
@@ -412,12 +380,13 @@ int main() {
412
  fseek(f, 0, SEEK_END);
413
  long fsize = ftell(f);
414
  fseek(f, 5*(long)sizeof(int), SEEK_SET);
415
-
416
  long wbytes = fsize - 5*(long)sizeof(int);
 
417
  g_model_data = (float*)malloc(wbytes);
418
  if (!g_model_data) { printf("ERROR oom_loading_model\n"); fflush(stdout); return 1; }
419
  fread(g_model_data, 1, wbytes, f);
420
  fclose(f);
 
421
  map_weights(g_model_data);
422
 
423
  const int C = cfg.n_embd;
@@ -429,16 +398,15 @@ int main() {
429
  g_logits = (float*)malloc(cfg.vocab_size*sizeof(float));
430
 
431
  srand((unsigned int)time(NULL));
432
-
433
- printf("READY\n"); fflush(stdout); // Python waits for this
434
 
435
  std::string line;
436
  while (std::getline(std::cin, line)) {
437
  if (!line.empty() && line.back()=='\r') line.pop_back();
438
  if (line.empty()) continue;
439
- if (line == "QUIT") break;
440
- else if (line.rfind("RESET|",0)==0) handle_reset(line);
441
- else if (line.rfind("REQUEST|",0)==0) handle_request(line);
442
  else { printf("ERROR unknown_cmd\n"); fflush(stdout); }
443
  }
444
 
@@ -446,4 +414,4 @@ int main() {
446
  free(g_model_data);
447
  free(g_x); free(g_buf); free(g_qkv); free(g_attn); free(g_ff); free(g_logits);
448
  return 0;
449
- }
 
3
  * KVInfer β€” PERSISTENT DAEMON INFERENCE ENGINE v2.0
4
  * ============================================================
5
  *
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  * ── STDIN PROTOCOL ──────────────────────────────────────────
7
  * REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
8
  * RESET|<sess>
9
  * QUIT
10
  *
11
  * ── STDOUT PROTOCOL ─────────────────────────────────────────
12
+ * READY
13
+ * TOKEN <id> <elapsed_ms>
14
+ * DONE <count> <total_ms>
15
+ * RESET_OK
16
  * ERROR <message>
17
  *
18
+ * ── COMPILE (GCC / Linux) ───────────────────────────────────
19
+ * g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 -o inference inference.cpp
 
 
 
20
  * ============================================================
21
  */
 
22
  #include <stdio.h>
23
  #include <stdlib.h>
24
  #include <math.h>
25
  #include <string.h>
26
+ #include <iostream>
27
  #include <time.h>
28
  #include <algorithm>
29
  #include <string>
 
31
  #include <unordered_set>
32
  #include <vector>
33
  #include <immintrin.h> // AVX2 + FMA
 
34
  #ifdef _OPENMP
35
  #include <omp.h>
36
  #endif
 
37
  #ifdef _WIN32
38
  #include <windows.h>
39
  static double get_time_ms() {
 
53
  // ─────────────────────────────────────────────────────────────────────────
54
  // Model Structures
55
  // ─────────────────────────────────────────────────────────────────────────
 
56
  typedef struct { int n_layer, n_head, n_embd, block_size, vocab_size; } Config;
 
57
  typedef struct {
58
  float *wte, *wpe;
59
  float **ln1_w, **ln1_b;
 
69
  struct SessionState {
70
  float* k_cache = nullptr;
71
  float* v_cache = nullptr;
72
+ int pos = 0;
73
  double last_used = 0.0;
74
  };
75
 
 
77
  static Weights W;
78
  static float* g_model_data = nullptr;
79
 
80
+ // ─────────────────────────────────────────────────────────────────────────
81
+ // MAX_SESSIONS β€” 3 engines Γ— 14 sessions Γ— 96MB = ~4GB KV cache
82
+ // Total RAM: ~6.57GB (safe under HF 8GB)
83
+ // ─────────────────────────────────────────────────────────────────────────
84
+ static const int MAX_SESSIONS = 14;
85
+
86
  static std::unordered_map<std::string, SessionState> g_sessions;
87
 
88
  // Shared per-request working buffers
89
  static float *g_x, *g_buf, *g_qkv, *g_attn, *g_ff, *g_logits;
90
 
91
  // ─────────────────────────────────────────────────────────────────────────
92
+ // Math Kernels (AVX2 + FMA + OpenMP)
93
  // ─────────────────────────────────────────────────────────────────────────
 
94
  static void layer_norm(float* out, const float* x, const float* w,
95
  const float* b, int N) {
96
  float mean = 0.f, var = 0.f;
 
147
  }
148
 
149
  // ─────────────────────────────────────────────────────────────────────────
150
+ // Transformer Forward (single token at position `pos`)
 
151
  // ─────────────────────────────────────────────────────────────────────────
 
152
  static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
153
  const int C = cfg.n_embd, H = cfg.n_head, hs = C/H;
 
154
  float* te = W.wte + (long long)token_id*C;
155
  float* pe = W.wpe + (long long)pos*C;
156
  #pragma omp parallel for
157
  for (int i = 0; i < C; i++) g_x[i] = te[i]+pe[i];
158
 
159
  for (int l = 0; l < cfg.n_layer; l++) {
 
160
  // Self-attention
161
  layer_norm(g_buf, g_x, W.ln1_w[l], W.ln1_b[l], C);
162
  matmul_vec(g_qkv, W.c_attn_w[l], g_buf, 3*C, C);
 
211
  // ─────────────────────────────────────────────────────────────────────────
212
  // Weight Mapping
213
  // ─────────────────────────────────────────────────────────────────────────
 
214
  static void map_weights(float* data) {
215
  float* p = data;
216
  const int C = cfg.n_embd, L = cfg.n_layer;
217
  W.wte=p; p+=(long long)cfg.vocab_size*C;
218
  W.wpe=p; p+=(long long)cfg.block_size*C;
 
219
  #define ARR(f) W.f=(float**)malloc(L*sizeof(float*))
220
  ARR(ln1_w); ARR(ln1_b); ARR(c_attn_w); ARR(c_attn_b);
221
  ARR(c_proj_w); ARR(c_proj_b); ARR(ln2_w); ARR(ln2_b);
222
  ARR(fc_w); ARR(fc_b); ARR(mlp_proj_w); ARR(mlp_proj_b);
223
  #undef ARR
 
224
  for (int l = 0; l < L; l++) {
225
+ W.ln1_w[l]=p; p+=C; W.ln1_b[l]=p; p+=C;
226
  W.c_attn_w[l]=p; p+=3LL*C*C; W.c_attn_b[l]=p; p+=3LL*C;
227
  W.c_proj_w[l]=p; p+=1LL*C*C; W.c_proj_b[l]=p; p+=C;
228
+ W.ln2_w[l]=p; p+=C; W.ln2_b[l]=p; p+=C;
229
+ W.fc_w[l]=p; p+=4LL*C*C; W.fc_b[l]=p; p+=4LL*C;
230
  W.mlp_proj_w[l]=p; p+=1LL*C*4*C; W.mlp_proj_b[l]=p; p+=C;
231
  }
232
  W.ln_f_w=p; p+=C; W.ln_f_b=p; p+=C; W.lm_head_w=p;
233
  }
234
 
235
  // ─────────────────────────────────────────────────────────────────────────
236
+ // Session Management (LRU eviction when MAX_SESSIONS reached)
237
  // ─────────────────────────────────────────────────────────────────────────
 
238
  static long long kv_alloc_bytes() {
239
  return (long long)cfg.n_layer * cfg.block_size * cfg.n_embd * sizeof(float);
240
  }
 
271
  }
272
 
273
  // ─────────────────────────────────────────────────────────────────────────
274
+ // Sampler (Top-K)
275
  // ─────────────────────────────────────────────────────────────────────────
 
276
  static int sample_topk(float temperature, int top_k) {
277
  for (int v = 0; v < cfg.vocab_size; v++) g_logits[v] /= temperature;
278
  std::vector<std::pair<float,int>> pairs(cfg.vocab_size);
279
  for (int v = 0; v < cfg.vocab_size; v++) pairs[v]={g_logits[v],v};
280
  std::partial_sort(pairs.begin(), pairs.begin()+top_k, pairs.end(),
281
+ [](const std::pair<float,int>& a, const std::pair<float,int>& b){
282
+ return a.first > b.first;
283
+ });
284
  float sum=0.f;
285
  for (int j=0; j<top_k; j++) { pairs[j].first=expf(pairs[j].first); sum+=pairs[j].first; }
286
  for (int j=0; j<top_k; j++) pairs[j].first /= sum;
 
293
  // ─────────────────────────────────────────��───────────────────────────────
294
  // Helpers
295
  // ─────────────────────────────────────────────────────────────────────────
 
296
  static std::vector<std::string> split(const std::string& s, char d) {
297
  std::vector<std::string> out; std::string cur;
298
  for (char c:s){ if(c==d){out.push_back(cur);cur.clear();}else cur+=c; }
 
309
  // ─────────────────────────────────────────────────────────────────────────
310
  // Command Handlers
311
  // ─────────────────────────────────────────────────────────────────────────
 
 
312
  static void handle_request(const std::string& line) {
313
  auto parts = split(line, '|');
314
  if (parts.size() < 7) {
315
  printf("ERROR bad_request_format\n"); fflush(stdout); return;
316
  }
317
+ std::string sess_id = parts[1];
318
+ auto new_tokens = parse_ints(parts[2]);
319
+ int max_new = atoi(parts[3].c_str());
320
+ float temp = (float)atof(parts[4].c_str());
321
+ int top_k = atoi(parts[5].c_str());
322
+ auto stop_list = parse_ints(parts[6]);
323
 
324
  if (temp < 0.01f) temp = 0.01f;
325
  if (top_k < 1) top_k = 1;
 
327
  if (max_new < 1) max_new = 1;
328
 
329
  std::unordered_set<int> stop_ids(stop_list.begin(), stop_list.end());
330
+ stop_ids.insert(50256); // <|endoftext|> always stop
331
 
332
  SessionState& sess = get_or_create(sess_id);
333
 
334
+ // Prefill new tokens into KV cache
335
  for (int tok : new_tokens) {
336
  if (sess.pos >= cfg.block_size) {
337
  printf("ERROR context_window_full\n"); fflush(stdout); return;
 
340
  sess.pos++;
341
  }
342
 
343
+ // Autoregressive generation
344
  double t0 = get_time_ms();
345
  int gen = 0;
 
346
  for (int i = 0; i < max_new; i++) {
347
  if (sess.pos >= cfg.block_size) break;
348
  int best = sample_topk(temp, top_k);
 
358
  fflush(stdout);
359
  }
360
 
 
361
  static void handle_reset(const std::string& line) {
362
  auto parts = split(line, '|');
363
  if (parts.size() < 2) { printf("RESET_OK\n"); fflush(stdout); return; }
 
370
  }
371
 
372
  // ─────────────────────────────────────────────────────────────────────────
373
+ // MAIN β€” model ek baar load, phir stdin se commands serve karo
374
  // ─────────────────────────────────────────────────────────────────────────
 
375
  int main() {
376
  FILE* f = fopen("model.bin", "rb");
377
  if (!f) { printf("ERROR model.bin_not_found\n"); fflush(stdout); return 1; }
 
380
  fseek(f, 0, SEEK_END);
381
  long fsize = ftell(f);
382
  fseek(f, 5*(long)sizeof(int), SEEK_SET);
 
383
  long wbytes = fsize - 5*(long)sizeof(int);
384
+
385
  g_model_data = (float*)malloc(wbytes);
386
  if (!g_model_data) { printf("ERROR oom_loading_model\n"); fflush(stdout); return 1; }
387
  fread(g_model_data, 1, wbytes, f);
388
  fclose(f);
389
+
390
  map_weights(g_model_data);
391
 
392
  const int C = cfg.n_embd;
 
398
  g_logits = (float*)malloc(cfg.vocab_size*sizeof(float));
399
 
400
  srand((unsigned int)time(NULL));
401
+ printf("READY\n"); fflush(stdout); // Python waits for this
 
402
 
403
  std::string line;
404
  while (std::getline(std::cin, line)) {
405
  if (!line.empty() && line.back()=='\r') line.pop_back();
406
  if (line.empty()) continue;
407
+ if (line == "QUIT") break;
408
+ else if (line.rfind("RESET|",0)==0) handle_reset(line);
409
+ else if (line.rfind("REQUEST|",0)==0) handle_request(line);
410
  else { printf("ERROR unknown_cmd\n"); fflush(stdout); }
411
  }
412
 
 
414
  free(g_model_data);
415
  free(g_x); free(g_buf); free(g_qkv); free(g_attn); free(g_ff); free(g_logits);
416
  return 0;
417
+ }