Update inference.cpp
Browse files- inference.cpp +41 -73
inference.cpp
CHANGED
|
@@ -3,44 +3,27 @@
|
|
| 3 |
* KVInfer β PERSISTENT DAEMON INFERENCE ENGINE v2.0
|
| 4 |
* ============================================================
|
| 5 |
*
|
| 6 |
-
* FIX #1 Persistent process: model loads ONCE at startup.
|
| 7 |
-
* Handles unlimited requests over stdin/stdout pipe.
|
| 8 |
-
* No more subprocess-per-request overhead.
|
| 9 |
-
*
|
| 10 |
-
* FIX #3 Session KV-cache reuse: each session_id keeps its
|
| 11 |
-
* own KV cache + position. New chat turns only run
|
| 12 |
-
* forward() on NEW tokens β full history stays cached.
|
| 13 |
-
* Massive TTFT reduction on multi-turn conversations.
|
| 14 |
-
*
|
| 15 |
-
* FIX #4 Stop-token list: caller passes extra stop IDs (e.g.
|
| 16 |
-
* the encoded <|user|> token) so the model cannot bleed
|
| 17 |
-
* into the next speaker's turn.
|
| 18 |
-
*
|
| 19 |
* ββ STDIN PROTOCOL ββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
* REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
|
| 21 |
* RESET|<sess>
|
| 22 |
* QUIT
|
| 23 |
*
|
| 24 |
* ββ STDOUT PROTOCOL βββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
-
* READY
|
| 26 |
-
* TOKEN <id> <elapsed_ms>
|
| 27 |
-
* DONE <count> <total_ms>
|
| 28 |
-
* RESET_OK
|
| 29 |
* ERROR <message>
|
| 30 |
*
|
| 31 |
-
* ββ COMPILE (
|
| 32 |
-
*
|
| 33 |
-
*
|
| 34 |
-
* ββ COMPILE (GCC / MinGW) βββββββββββββββββββββββββββββββββββ
|
| 35 |
-
* g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 -o inference.exe inference.cpp
|
| 36 |
* ============================================================
|
| 37 |
*/
|
| 38 |
-
|
| 39 |
#include <stdio.h>
|
| 40 |
#include <stdlib.h>
|
| 41 |
#include <math.h>
|
| 42 |
#include <string.h>
|
| 43 |
-
#include <iostream>
|
| 44 |
#include <time.h>
|
| 45 |
#include <algorithm>
|
| 46 |
#include <string>
|
|
@@ -48,11 +31,9 @@
|
|
| 48 |
#include <unordered_set>
|
| 49 |
#include <vector>
|
| 50 |
#include <immintrin.h> // AVX2 + FMA
|
| 51 |
-
|
| 52 |
#ifdef _OPENMP
|
| 53 |
#include <omp.h>
|
| 54 |
#endif
|
| 55 |
-
|
| 56 |
#ifdef _WIN32
|
| 57 |
#include <windows.h>
|
| 58 |
static double get_time_ms() {
|
|
@@ -72,9 +53,7 @@
|
|
| 72 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
// Model Structures
|
| 74 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 75 |
-
|
| 76 |
typedef struct { int n_layer, n_head, n_embd, block_size, vocab_size; } Config;
|
| 77 |
-
|
| 78 |
typedef struct {
|
| 79 |
float *wte, *wpe;
|
| 80 |
float **ln1_w, **ln1_b;
|
|
@@ -90,7 +69,7 @@ typedef struct {
|
|
| 90 |
struct SessionState {
|
| 91 |
float* k_cache = nullptr;
|
| 92 |
float* v_cache = nullptr;
|
| 93 |
-
int pos = 0;
|
| 94 |
double last_used = 0.0;
|
| 95 |
};
|
| 96 |
|
|
@@ -98,17 +77,20 @@ static Config cfg;
|
|
| 98 |
static Weights W;
|
| 99 |
static float* g_model_data = nullptr;
|
| 100 |
|
| 101 |
-
//
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
static std::unordered_map<std::string, SessionState> g_sessions;
|
| 104 |
|
| 105 |
// Shared per-request working buffers
|
| 106 |
static float *g_x, *g_buf, *g_qkv, *g_attn, *g_ff, *g_logits;
|
| 107 |
|
| 108 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
-
// Math Kernels
|
| 110 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 111 |
-
|
| 112 |
static void layer_norm(float* out, const float* x, const float* w,
|
| 113 |
const float* b, int N) {
|
| 114 |
float mean = 0.f, var = 0.f;
|
|
@@ -165,20 +147,16 @@ static void softmax_inplace(float* x, int N) {
|
|
| 165 |
}
|
| 166 |
|
| 167 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
-
// Transformer Forward
|
| 169 |
-
// Writes next-token log-probs into g_logits.
|
| 170 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
-
|
| 172 |
static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
|
| 173 |
const int C = cfg.n_embd, H = cfg.n_head, hs = C/H;
|
| 174 |
-
|
| 175 |
float* te = W.wte + (long long)token_id*C;
|
| 176 |
float* pe = W.wpe + (long long)pos*C;
|
| 177 |
#pragma omp parallel for
|
| 178 |
for (int i = 0; i < C; i++) g_x[i] = te[i]+pe[i];
|
| 179 |
|
| 180 |
for (int l = 0; l < cfg.n_layer; l++) {
|
| 181 |
-
|
| 182 |
// Self-attention
|
| 183 |
layer_norm(g_buf, g_x, W.ln1_w[l], W.ln1_b[l], C);
|
| 184 |
matmul_vec(g_qkv, W.c_attn_w[l], g_buf, 3*C, C);
|
|
@@ -233,34 +211,30 @@ static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
|
|
| 233 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 234 |
// Weight Mapping
|
| 235 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 236 |
-
|
| 237 |
static void map_weights(float* data) {
|
| 238 |
float* p = data;
|
| 239 |
const int C = cfg.n_embd, L = cfg.n_layer;
|
| 240 |
W.wte=p; p+=(long long)cfg.vocab_size*C;
|
| 241 |
W.wpe=p; p+=(long long)cfg.block_size*C;
|
| 242 |
-
|
| 243 |
#define ARR(f) W.f=(float**)malloc(L*sizeof(float*))
|
| 244 |
ARR(ln1_w); ARR(ln1_b); ARR(c_attn_w); ARR(c_attn_b);
|
| 245 |
ARR(c_proj_w); ARR(c_proj_b); ARR(ln2_w); ARR(ln2_b);
|
| 246 |
ARR(fc_w); ARR(fc_b); ARR(mlp_proj_w); ARR(mlp_proj_b);
|
| 247 |
#undef ARR
|
| 248 |
-
|
| 249 |
for (int l = 0; l < L; l++) {
|
| 250 |
-
W.ln1_w[l]=p; p+=C;
|
| 251 |
W.c_attn_w[l]=p; p+=3LL*C*C; W.c_attn_b[l]=p; p+=3LL*C;
|
| 252 |
W.c_proj_w[l]=p; p+=1LL*C*C; W.c_proj_b[l]=p; p+=C;
|
| 253 |
-
W.ln2_w[l]=p; p+=C;
|
| 254 |
-
W.fc_w[l]=p; p+=4LL*C*C;
|
| 255 |
W.mlp_proj_w[l]=p; p+=1LL*C*4*C; W.mlp_proj_b[l]=p; p+=C;
|
| 256 |
}
|
| 257 |
W.ln_f_w=p; p+=C; W.ln_f_b=p; p+=C; W.lm_head_w=p;
|
| 258 |
}
|
| 259 |
|
| 260 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 261 |
-
// Session Management
|
| 262 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 263 |
-
|
| 264 |
static long long kv_alloc_bytes() {
|
| 265 |
return (long long)cfg.n_layer * cfg.block_size * cfg.n_embd * sizeof(float);
|
| 266 |
}
|
|
@@ -297,16 +271,16 @@ static SessionState& get_or_create(const std::string& id) {
|
|
| 297 |
}
|
| 298 |
|
| 299 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 300 |
-
// Sampler
|
| 301 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 302 |
-
|
| 303 |
static int sample_topk(float temperature, int top_k) {
|
| 304 |
for (int v = 0; v < cfg.vocab_size; v++) g_logits[v] /= temperature;
|
| 305 |
std::vector<std::pair<float,int>> pairs(cfg.vocab_size);
|
| 306 |
for (int v = 0; v < cfg.vocab_size; v++) pairs[v]={g_logits[v],v};
|
| 307 |
std::partial_sort(pairs.begin(), pairs.begin()+top_k, pairs.end(),
|
| 308 |
-
[](const std::pair<float,int>& a,const std::pair<float,int>& b){
|
| 309 |
-
return a.first>b.first;
|
|
|
|
| 310 |
float sum=0.f;
|
| 311 |
for (int j=0; j<top_k; j++) { pairs[j].first=expf(pairs[j].first); sum+=pairs[j].first; }
|
| 312 |
for (int j=0; j<top_k; j++) pairs[j].first /= sum;
|
|
@@ -319,7 +293,6 @@ static int sample_topk(float temperature, int top_k) {
|
|
| 319 |
// βββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββ
|
| 320 |
// Helpers
|
| 321 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 322 |
-
|
| 323 |
static std::vector<std::string> split(const std::string& s, char d) {
|
| 324 |
std::vector<std::string> out; std::string cur;
|
| 325 |
for (char c:s){ if(c==d){out.push_back(cur);cur.clear();}else cur+=c; }
|
|
@@ -336,19 +309,17 @@ static std::vector<int> parse_ints(const std::string& s) {
|
|
| 336 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 337 |
// Command Handlers
|
| 338 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 339 |
-
|
| 340 |
-
// REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
|
| 341 |
static void handle_request(const std::string& line) {
|
| 342 |
auto parts = split(line, '|');
|
| 343 |
if (parts.size() < 7) {
|
| 344 |
printf("ERROR bad_request_format\n"); fflush(stdout); return;
|
| 345 |
}
|
| 346 |
-
std::string sess_id
|
| 347 |
-
auto new_tokens
|
| 348 |
-
int max_new
|
| 349 |
-
float temp
|
| 350 |
-
int top_k
|
| 351 |
-
auto stop_list
|
| 352 |
|
| 353 |
if (temp < 0.01f) temp = 0.01f;
|
| 354 |
if (top_k < 1) top_k = 1;
|
|
@@ -356,11 +327,11 @@ static void handle_request(const std::string& line) {
|
|
| 356 |
if (max_new < 1) max_new = 1;
|
| 357 |
|
| 358 |
std::unordered_set<int> stop_ids(stop_list.begin(), stop_list.end());
|
| 359 |
-
stop_ids.insert(50256);
|
| 360 |
|
| 361 |
SessionState& sess = get_or_create(sess_id);
|
| 362 |
|
| 363 |
-
//
|
| 364 |
for (int tok : new_tokens) {
|
| 365 |
if (sess.pos >= cfg.block_size) {
|
| 366 |
printf("ERROR context_window_full\n"); fflush(stdout); return;
|
|
@@ -369,10 +340,9 @@ static void handle_request(const std::string& line) {
|
|
| 369 |
sess.pos++;
|
| 370 |
}
|
| 371 |
|
| 372 |
-
//
|
| 373 |
double t0 = get_time_ms();
|
| 374 |
int gen = 0;
|
| 375 |
-
|
| 376 |
for (int i = 0; i < max_new; i++) {
|
| 377 |
if (sess.pos >= cfg.block_size) break;
|
| 378 |
int best = sample_topk(temp, top_k);
|
|
@@ -388,7 +358,6 @@ static void handle_request(const std::string& line) {
|
|
| 388 |
fflush(stdout);
|
| 389 |
}
|
| 390 |
|
| 391 |
-
// RESET|<sess>
|
| 392 |
static void handle_reset(const std::string& line) {
|
| 393 |
auto parts = split(line, '|');
|
| 394 |
if (parts.size() < 2) { printf("RESET_OK\n"); fflush(stdout); return; }
|
|
@@ -401,9 +370,8 @@ static void handle_reset(const std::string& line) {
|
|
| 401 |
}
|
| 402 |
|
| 403 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 404 |
-
// MAIN β
|
| 405 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 406 |
-
|
| 407 |
int main() {
|
| 408 |
FILE* f = fopen("model.bin", "rb");
|
| 409 |
if (!f) { printf("ERROR model.bin_not_found\n"); fflush(stdout); return 1; }
|
|
@@ -412,12 +380,13 @@ int main() {
|
|
| 412 |
fseek(f, 0, SEEK_END);
|
| 413 |
long fsize = ftell(f);
|
| 414 |
fseek(f, 5*(long)sizeof(int), SEEK_SET);
|
| 415 |
-
|
| 416 |
long wbytes = fsize - 5*(long)sizeof(int);
|
|
|
|
| 417 |
g_model_data = (float*)malloc(wbytes);
|
| 418 |
if (!g_model_data) { printf("ERROR oom_loading_model\n"); fflush(stdout); return 1; }
|
| 419 |
fread(g_model_data, 1, wbytes, f);
|
| 420 |
fclose(f);
|
|
|
|
| 421 |
map_weights(g_model_data);
|
| 422 |
|
| 423 |
const int C = cfg.n_embd;
|
|
@@ -429,16 +398,15 @@ int main() {
|
|
| 429 |
g_logits = (float*)malloc(cfg.vocab_size*sizeof(float));
|
| 430 |
|
| 431 |
srand((unsigned int)time(NULL));
|
| 432 |
-
|
| 433 |
-
printf("READY\n"); fflush(stdout); // Python waits for this
|
| 434 |
|
| 435 |
std::string line;
|
| 436 |
while (std::getline(std::cin, line)) {
|
| 437 |
if (!line.empty() && line.back()=='\r') line.pop_back();
|
| 438 |
if (line.empty()) continue;
|
| 439 |
-
if (line == "QUIT")
|
| 440 |
-
else if (line.rfind("RESET|",0)==0)
|
| 441 |
-
else if (line.rfind("REQUEST|",0)==0)
|
| 442 |
else { printf("ERROR unknown_cmd\n"); fflush(stdout); }
|
| 443 |
}
|
| 444 |
|
|
@@ -446,4 +414,4 @@ int main() {
|
|
| 446 |
free(g_model_data);
|
| 447 |
free(g_x); free(g_buf); free(g_qkv); free(g_attn); free(g_ff); free(g_logits);
|
| 448 |
return 0;
|
| 449 |
-
}
|
|
|
|
| 3 |
* KVInfer β PERSISTENT DAEMON INFERENCE ENGINE v2.0
|
| 4 |
* ============================================================
|
| 5 |
*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
* ββ STDIN PROTOCOL ββββββββββββββββββββββββββββββββββββββββββ
|
| 7 |
* REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
|
| 8 |
* RESET|<sess>
|
| 9 |
* QUIT
|
| 10 |
*
|
| 11 |
* ββ STDOUT PROTOCOL βββββββββββββββββββββββββββββββββββββββββ
|
| 12 |
+
* READY
|
| 13 |
+
* TOKEN <id> <elapsed_ms>
|
| 14 |
+
* DONE <count> <total_ms>
|
| 15 |
+
* RESET_OK
|
| 16 |
* ERROR <message>
|
| 17 |
*
|
| 18 |
+
* ββ COMPILE (GCC / Linux) βββββββββββββββββββββββββββββββββββ
|
| 19 |
+
* g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 -o inference inference.cpp
|
|
|
|
|
|
|
|
|
|
| 20 |
* ============================================================
|
| 21 |
*/
|
|
|
|
| 22 |
#include <stdio.h>
|
| 23 |
#include <stdlib.h>
|
| 24 |
#include <math.h>
|
| 25 |
#include <string.h>
|
| 26 |
+
#include <iostream>
|
| 27 |
#include <time.h>
|
| 28 |
#include <algorithm>
|
| 29 |
#include <string>
|
|
|
|
| 31 |
#include <unordered_set>
|
| 32 |
#include <vector>
|
| 33 |
#include <immintrin.h> // AVX2 + FMA
|
|
|
|
| 34 |
#ifdef _OPENMP
|
| 35 |
#include <omp.h>
|
| 36 |
#endif
|
|
|
|
| 37 |
#ifdef _WIN32
|
| 38 |
#include <windows.h>
|
| 39 |
static double get_time_ms() {
|
|
|
|
| 53 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
// Model Structures
|
| 55 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 56 |
typedef struct { int n_layer, n_head, n_embd, block_size, vocab_size; } Config;
|
|
|
|
| 57 |
typedef struct {
|
| 58 |
float *wte, *wpe;
|
| 59 |
float **ln1_w, **ln1_b;
|
|
|
|
| 69 |
struct SessionState {
|
| 70 |
float* k_cache = nullptr;
|
| 71 |
float* v_cache = nullptr;
|
| 72 |
+
int pos = 0;
|
| 73 |
double last_used = 0.0;
|
| 74 |
};
|
| 75 |
|
|
|
|
| 77 |
static Weights W;
|
| 78 |
static float* g_model_data = nullptr;
|
| 79 |
|
| 80 |
+
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
// MAX_SESSIONS β 3 engines Γ 14 sessions Γ 96MB = ~4GB KV cache
|
| 82 |
+
// Total RAM: ~6.57GB (safe under HF 8GB)
|
| 83 |
+
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
static const int MAX_SESSIONS = 14;
|
| 85 |
+
|
| 86 |
static std::unordered_map<std::string, SessionState> g_sessions;
|
| 87 |
|
| 88 |
// Shared per-request working buffers
|
| 89 |
static float *g_x, *g_buf, *g_qkv, *g_attn, *g_ff, *g_logits;
|
| 90 |
|
| 91 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 92 |
+
// Math Kernels (AVX2 + FMA + OpenMP)
|
| 93 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 94 |
static void layer_norm(float* out, const float* x, const float* w,
|
| 95 |
const float* b, int N) {
|
| 96 |
float mean = 0.f, var = 0.f;
|
|
|
|
| 147 |
}
|
| 148 |
|
| 149 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
// Transformer Forward (single token at position `pos`)
|
|
|
|
| 151 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 152 |
static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
|
| 153 |
const int C = cfg.n_embd, H = cfg.n_head, hs = C/H;
|
|
|
|
| 154 |
float* te = W.wte + (long long)token_id*C;
|
| 155 |
float* pe = W.wpe + (long long)pos*C;
|
| 156 |
#pragma omp parallel for
|
| 157 |
for (int i = 0; i < C; i++) g_x[i] = te[i]+pe[i];
|
| 158 |
|
| 159 |
for (int l = 0; l < cfg.n_layer; l++) {
|
|
|
|
| 160 |
// Self-attention
|
| 161 |
layer_norm(g_buf, g_x, W.ln1_w[l], W.ln1_b[l], C);
|
| 162 |
matmul_vec(g_qkv, W.c_attn_w[l], g_buf, 3*C, C);
|
|
|
|
| 211 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 212 |
// Weight Mapping
|
| 213 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 214 |
static void map_weights(float* data) {
|
| 215 |
float* p = data;
|
| 216 |
const int C = cfg.n_embd, L = cfg.n_layer;
|
| 217 |
W.wte=p; p+=(long long)cfg.vocab_size*C;
|
| 218 |
W.wpe=p; p+=(long long)cfg.block_size*C;
|
|
|
|
| 219 |
#define ARR(f) W.f=(float**)malloc(L*sizeof(float*))
|
| 220 |
ARR(ln1_w); ARR(ln1_b); ARR(c_attn_w); ARR(c_attn_b);
|
| 221 |
ARR(c_proj_w); ARR(c_proj_b); ARR(ln2_w); ARR(ln2_b);
|
| 222 |
ARR(fc_w); ARR(fc_b); ARR(mlp_proj_w); ARR(mlp_proj_b);
|
| 223 |
#undef ARR
|
|
|
|
| 224 |
for (int l = 0; l < L; l++) {
|
| 225 |
+
W.ln1_w[l]=p; p+=C; W.ln1_b[l]=p; p+=C;
|
| 226 |
W.c_attn_w[l]=p; p+=3LL*C*C; W.c_attn_b[l]=p; p+=3LL*C;
|
| 227 |
W.c_proj_w[l]=p; p+=1LL*C*C; W.c_proj_b[l]=p; p+=C;
|
| 228 |
+
W.ln2_w[l]=p; p+=C; W.ln2_b[l]=p; p+=C;
|
| 229 |
+
W.fc_w[l]=p; p+=4LL*C*C; W.fc_b[l]=p; p+=4LL*C;
|
| 230 |
W.mlp_proj_w[l]=p; p+=1LL*C*4*C; W.mlp_proj_b[l]=p; p+=C;
|
| 231 |
}
|
| 232 |
W.ln_f_w=p; p+=C; W.ln_f_b=p; p+=C; W.lm_head_w=p;
|
| 233 |
}
|
| 234 |
|
| 235 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 236 |
+
// Session Management (LRU eviction when MAX_SESSIONS reached)
|
| 237 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 238 |
static long long kv_alloc_bytes() {
|
| 239 |
return (long long)cfg.n_layer * cfg.block_size * cfg.n_embd * sizeof(float);
|
| 240 |
}
|
|
|
|
| 271 |
}
|
| 272 |
|
| 273 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 274 |
+
// Sampler (Top-K)
|
| 275 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 276 |
static int sample_topk(float temperature, int top_k) {
|
| 277 |
for (int v = 0; v < cfg.vocab_size; v++) g_logits[v] /= temperature;
|
| 278 |
std::vector<std::pair<float,int>> pairs(cfg.vocab_size);
|
| 279 |
for (int v = 0; v < cfg.vocab_size; v++) pairs[v]={g_logits[v],v};
|
| 280 |
std::partial_sort(pairs.begin(), pairs.begin()+top_k, pairs.end(),
|
| 281 |
+
[](const std::pair<float,int>& a, const std::pair<float,int>& b){
|
| 282 |
+
return a.first > b.first;
|
| 283 |
+
});
|
| 284 |
float sum=0.f;
|
| 285 |
for (int j=0; j<top_k; j++) { pairs[j].first=expf(pairs[j].first); sum+=pairs[j].first; }
|
| 286 |
for (int j=0; j<top_k; j++) pairs[j].first /= sum;
|
|
|
|
| 293 |
// βββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββ
|
| 294 |
// Helpers
|
| 295 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 296 |
static std::vector<std::string> split(const std::string& s, char d) {
|
| 297 |
std::vector<std::string> out; std::string cur;
|
| 298 |
for (char c:s){ if(c==d){out.push_back(cur);cur.clear();}else cur+=c; }
|
|
|
|
| 309 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 310 |
// Command Handlers
|
| 311 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
| 312 |
static void handle_request(const std::string& line) {
|
| 313 |
auto parts = split(line, '|');
|
| 314 |
if (parts.size() < 7) {
|
| 315 |
printf("ERROR bad_request_format\n"); fflush(stdout); return;
|
| 316 |
}
|
| 317 |
+
std::string sess_id = parts[1];
|
| 318 |
+
auto new_tokens = parse_ints(parts[2]);
|
| 319 |
+
int max_new = atoi(parts[3].c_str());
|
| 320 |
+
float temp = (float)atof(parts[4].c_str());
|
| 321 |
+
int top_k = atoi(parts[5].c_str());
|
| 322 |
+
auto stop_list = parse_ints(parts[6]);
|
| 323 |
|
| 324 |
if (temp < 0.01f) temp = 0.01f;
|
| 325 |
if (top_k < 1) top_k = 1;
|
|
|
|
| 327 |
if (max_new < 1) max_new = 1;
|
| 328 |
|
| 329 |
std::unordered_set<int> stop_ids(stop_list.begin(), stop_list.end());
|
| 330 |
+
stop_ids.insert(50256); // <|endoftext|> always stop
|
| 331 |
|
| 332 |
SessionState& sess = get_or_create(sess_id);
|
| 333 |
|
| 334 |
+
// Prefill new tokens into KV cache
|
| 335 |
for (int tok : new_tokens) {
|
| 336 |
if (sess.pos >= cfg.block_size) {
|
| 337 |
printf("ERROR context_window_full\n"); fflush(stdout); return;
|
|
|
|
| 340 |
sess.pos++;
|
| 341 |
}
|
| 342 |
|
| 343 |
+
// Autoregressive generation
|
| 344 |
double t0 = get_time_ms();
|
| 345 |
int gen = 0;
|
|
|
|
| 346 |
for (int i = 0; i < max_new; i++) {
|
| 347 |
if (sess.pos >= cfg.block_size) break;
|
| 348 |
int best = sample_topk(temp, top_k);
|
|
|
|
| 358 |
fflush(stdout);
|
| 359 |
}
|
| 360 |
|
|
|
|
| 361 |
static void handle_reset(const std::string& line) {
|
| 362 |
auto parts = split(line, '|');
|
| 363 |
if (parts.size() < 2) { printf("RESET_OK\n"); fflush(stdout); return; }
|
|
|
|
| 370 |
}
|
| 371 |
|
| 372 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 373 |
+
// MAIN β model ek baar load, phir stdin se commands serve karo
|
| 374 |
// βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 375 |
int main() {
|
| 376 |
FILE* f = fopen("model.bin", "rb");
|
| 377 |
if (!f) { printf("ERROR model.bin_not_found\n"); fflush(stdout); return 1; }
|
|
|
|
| 380 |
fseek(f, 0, SEEK_END);
|
| 381 |
long fsize = ftell(f);
|
| 382 |
fseek(f, 5*(long)sizeof(int), SEEK_SET);
|
|
|
|
| 383 |
long wbytes = fsize - 5*(long)sizeof(int);
|
| 384 |
+
|
| 385 |
g_model_data = (float*)malloc(wbytes);
|
| 386 |
if (!g_model_data) { printf("ERROR oom_loading_model\n"); fflush(stdout); return 1; }
|
| 387 |
fread(g_model_data, 1, wbytes, f);
|
| 388 |
fclose(f);
|
| 389 |
+
|
| 390 |
map_weights(g_model_data);
|
| 391 |
|
| 392 |
const int C = cfg.n_embd;
|
|
|
|
| 398 |
g_logits = (float*)malloc(cfg.vocab_size*sizeof(float));
|
| 399 |
|
| 400 |
srand((unsigned int)time(NULL));
|
| 401 |
+
printf("READY\n"); fflush(stdout); // Python waits for this
|
|
|
|
| 402 |
|
| 403 |
std::string line;
|
| 404 |
while (std::getline(std::cin, line)) {
|
| 405 |
if (!line.empty() && line.back()=='\r') line.pop_back();
|
| 406 |
if (line.empty()) continue;
|
| 407 |
+
if (line == "QUIT") break;
|
| 408 |
+
else if (line.rfind("RESET|",0)==0) handle_reset(line);
|
| 409 |
+
else if (line.rfind("REQUEST|",0)==0) handle_request(line);
|
| 410 |
else { printf("ERROR unknown_cmd\n"); fflush(stdout); }
|
| 411 |
}
|
| 412 |
|
|
|
|
| 414 |
free(g_model_data);
|
| 415 |
free(g_x); free(g_buf); free(g_qkv); free(g_attn); free(g_ff); free(g_logits);
|
| 416 |
return 0;
|
| 417 |
+
}
|