| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <immintrin.h>
#ifdef _OPENMP
#include <omp.h>
#endif
#ifdef _WIN32
#include <windows.h>
/* Milliseconds read from the high-resolution performance counter.
   The origin is arbitrary: only differences between two calls are meaningful. */
static double get_time_ms() {
    LARGE_INTEGER f, c;
    QueryPerformanceFrequency(&f);
    QueryPerformanceCounter(&c);
    return (double)c.QuadPart / f.QuadPart * 1000.0;
}
#else
#include <sys/time.h>
/* Milliseconds for elapsed-time measurement.

   Prefers CLOCK_MONOTONIC so that TOKEN/DONE latencies and the session LRU
   ordering (SessionState::last_used) are not disturbed by wall-clock
   adjustments (NTP slews, manual clock changes). Falls back to gettimeofday()
   where no monotonic clock is available. The origin is arbitrary: only
   differences between two calls are meaningful. */
static double get_time_ms() {
#ifdef CLOCK_MONOTONIC
    struct timespec ts;
    if (clock_gettime(CLOCK_MONOTONIC, &ts) == 0)
        return ts.tv_sec * 1000.0 + ts.tv_nsec / 1.0e6;
#endif
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0;
}
#endif
|
|
| |
| |
| |
/* Model hyper-parameters.
   main() fills this struct directly from the first five ints of model.bin,
   so the field order here must match the file header exactly. */
typedef struct { int n_layer, n_head, n_embd, block_size, vocab_size; } Config;

/* Views into the single weight blob loaded from model.bin (assigned by
   map_weights()). The float pointers do not own memory; only the per-layer
   pointer tables (the float** members) are separately malloc'd.
   Per-layer members are indexed [layer]. Matrices are row-major with the
   OUTPUT dimension as rows, as consumed by matmul_vec():
     c_attn_w   : (3*n_embd) x n_embd   fused Q/K/V projection
     c_proj_w   : n_embd x n_embd       attention output projection
     fc_w       : (4*n_embd) x n_embd   MLP expansion
     mlp_proj_w : n_embd x (4*n_embd)   MLP contraction
     lm_head_w  : vocab_size x n_embd   output logits */
typedef struct {
    float *wte, *wpe;                  // token (vocab_size x n_embd) / position (block_size x n_embd) embeddings
    float **ln1_w, **ln1_b;            // LayerNorm before attention
    float **c_attn_w, **c_attn_b;      // fused Q/K/V projection
    float **c_proj_w, **c_proj_b;      // attention output projection
    float **ln2_w, **ln2_b;            // LayerNorm before the MLP
    float **fc_w, **fc_b;              // MLP expansion (n_embd -> 4*n_embd)
    float **mlp_proj_w, **mlp_proj_b;  // MLP contraction (4*n_embd -> n_embd)
    float *ln_f_w, *ln_f_b;            // final LayerNorm
    float *lm_head_w;                  // projection to vocabulary logits
} Weights;

/* Per-client decoding state.
   k_cache / v_cache are heap buffers of kv_alloc_bytes() bytes each, laid out
   [layer][position][n_embd] (see forward()). pos is the number of tokens
   already written into them, i.e. the next position to fill. */
struct SessionState {
    float* k_cache = nullptr;
    float* v_cache = nullptr;
    int pos = 0;
    double last_used = 0.0;  // get_time_ms() at last access; evict_oldest() removes the smallest
};
|
|
static Config cfg;                     // model hyper-parameters, read from the model.bin header
static Weights W;                      // pointers into g_model_data (see map_weights())
static float* g_model_data = nullptr;  // owns every weight float; freed at exit

/* Upper bound on cached sessions. get_or_create() evicts the least recently
   used session once this many exist. Each session owns two KV buffers of
   kv_alloc_bytes() bytes, so this caps KV memory at
   2 * MAX_SESSIONS * kv_alloc_bytes(). */
static const int MAX_SESSIONS = 14;

// session id -> decoding state
static std::unordered_map<std::string, SessionState> g_sessions;

/* Scratch activations shared by every forward() call (allocated in main()):
   g_x C, g_buf C, g_qkv 3C, g_attn n_head*block_size, g_ff 4C,
   g_logits vocab_size floats (C = n_embd). Because they are global, forward()
   is not reentrant, and g_logits only holds the output of the most recent
   forward() call, whichever session that call was for. */
static float *g_x, *g_buf, *g_qkv, *g_attn, *g_ff, *g_logits;
|
|
| |
| |
| |
/* LayerNorm of a single length-N vector:
     out[k] = (x[k] - mean) / sqrt(var + 1e-5) * w[k] + b[k]
   where var is the population (biased) variance. Each out[k] is written only
   after x[k] has been read, so out may alias x. */
static void layer_norm(float* out, const float* x, const float* w,
                       const float* b, int N) {
    float mu = 0.f;
    for (int k = 0; k < N; ++k)
        mu += x[k];
    mu /= N;

    float sq = 0.f;
    for (int k = 0; k < N; ++k) {
        const float dev = x[k] - mu;
        sq += dev * dev;
    }
    sq /= N;

    const float inv_std = 1.f / sqrtf(sq + 1e-5f);
    for (int k = 0; k < N; ++k)
        out[k] = (x[k] - mu) * inv_std * w[k] + b[k];
}
|
|
| static void matmul_vec(float* out, const float* mat, const float* x, |
| int M, int K) { |
| #pragma omp parallel for schedule(static) |
| for (int i = 0; i < M; i++) { |
| const float* row = mat + (long long)i*K; |
| __m256 acc = _mm256_setzero_ps(); |
| int j = 0; |
| for (; j <= K-8; j += 8) |
| acc = _mm256_fmadd_ps(_mm256_loadu_ps(row+j), |
| _mm256_loadu_ps(x+j), acc); |
| float tmp[8]; _mm256_storeu_ps(tmp, acc); |
| float s = tmp[0]+tmp[1]+tmp[2]+tmp[3]+tmp[4]+tmp[5]+tmp[6]+tmp[7]; |
| for (; j < K; j++) s += row[j]*x[j]; |
| out[i] = s; |
| } |
| } |
|
|
/* Element-wise in-place bias add: x[k] += b[k] for k in [0, N). */
static inline void add_bias(float* x, const float* b, int N) {
#pragma omp parallel for
    for (int k = 0; k < N; ++k) {
        x[k] = x[k] + b[k];
    }
}
|
|
/* Residual connection: accumulate y into x element-wise, x[k] += y[k]. */
static inline void residual_add(float* x, const float* y, int N) {
#pragma omp parallel for
    for (int k = 0; k < N; ++k) {
        x[k] = x[k] + y[k];
    }
}
|
|
/* In-place GELU using the tanh approximation:
     gelu(v) = 0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3))) */
static void gelu_inplace(float* x, int N) {
    const float kSqrt2OverPi = 0.7978845608f;
#pragma omp parallel for
    for (int k = 0; k < N; ++k) {
        const float u = x[k];
        const float inner = kSqrt2OverPi * (u + 0.044715f * u * u * u);
        x[k] = 0.5f * u * (1.f + tanhf(inner));
    }
}
|
|
/* In-place softmax over x[0..N). The maximum is subtracted before
   exponentiating so large inputs cannot overflow expf. Reads x[0]
   unconditionally, so N must be at least 1. */
static void softmax_inplace(float* x, int N) {
    float peak = x[0];
    for (const float* p = x + 1; p < x + N; ++p)
        peak = (*p > peak) ? *p : peak;

    float total = 0.f;
    for (float* p = x; p < x + N; ++p) {
        *p = expf(*p - peak);
        total += *p;
    }

    for (float* p = x; p < x + N; ++p)
        *p /= total;
}
|
|
| |
| |
| |
/* Run one token through the model at position `pos`.

   Writes this token's K and V vectors into slot `pos` of the given per-session
   caches (layout [layer][block_size][n_embd]). It attends causally over slots
   0..pos and leaves the next-token logits in the global g_logits
   (vocab_size floats).

   Preconditions (NOT checked here):
     0 <= token_id < cfg.vocab_size, 0 <= pos < cfg.block_size, and slots
     0..pos-1 of both caches already hold this session's earlier tokens.
   Uses the global scratch buffers, so calls must not run concurrently. */
static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
    const int C = cfg.n_embd, H = cfg.n_head, hs = C/H;  // hs = per-head width
    // Input embedding = token embedding row + position embedding row.
    float* te = W.wte + (long long)token_id*C;
    float* pe = W.wpe + (long long)pos*C;
#pragma omp parallel for
    for (int i = 0; i < C; i++) g_x[i] = te[i]+pe[i];

    for (int l = 0; l < cfg.n_layer; l++) {
        // --- attention sub-block: LayerNorm first, then fused Q/K/V projection ---
        layer_norm(g_buf, g_x, W.ln1_w[l], W.ln1_b[l], C);
        matmul_vec(g_qkv, W.c_attn_w[l], g_buf, 3*C, C);
        add_bias(g_qkv, W.c_attn_b[l], 3*C);

        // g_qkv holds [q | k | v], each C floats; append k and v to this layer's cache.
        float* q = g_qkv, *k = g_qkv+C, *v = g_qkv+2*C;
        float* kc = k_cache + (long long)l*cfg.block_size*C;
        float* vc = v_cache + (long long)l*cfg.block_size*C;
        memcpy(kc+(long long)pos*C, k, C*sizeof(float));
        memcpy(vc+(long long)pos*C, v, C*sizeof(float));

        // Heads are independent: each thread uses its own slice of g_attn
        // (scores) and of g_buf (output), so there is no write sharing. The
        // LayerNorm output in g_buf has already been consumed by the projection above.
#pragma omp parallel for schedule(static)
        for (int h = 0; h < H; h++) {
            float* qh = q + h*hs;
            float sc = 1.f/sqrtf((float)hs);  // 1/sqrt(head_dim) score scaling
            float* la = g_attn + h*cfg.block_size;
            // Scaled dot-product scores against every cached key up to and including pos.
            for (int t = 0; t <= pos; t++) {
                float* kh = kc+(long long)t*C+h*hs;
                float dot = 0.f;
                for (int d = 0; d < hs; d++) dot += qh[d]*kh[d];
                la[t] = dot*sc;
            }
            softmax_inplace(la, pos+1);
            // Output for this head = attention-weighted sum of cached values.
            float* oh = g_buf+h*hs;
            memset(oh, 0, hs*sizeof(float));
            for (int t = 0; t <= pos; t++) {
                float* vh = vc+(long long)t*C+h*hs;
                float a = la[t];
                for (int d = 0; d < hs; d++) oh[d] += a*vh[d];
            }
        }

        // Output projection; q/k/v are no longer needed, so g_qkv is reused as scratch.
        float* ao = g_qkv;
        matmul_vec(ao, W.c_proj_w[l], g_buf, C, C);
        add_bias(ao, W.c_proj_b[l], C);
        residual_add(g_x, ao, C);

        // --- MLP sub-block: LayerNorm, C -> 4C, GELU, 4C -> C, residual ---
        layer_norm(g_buf, g_x, W.ln2_w[l], W.ln2_b[l], C);
        matmul_vec(g_ff, W.fc_w[l], g_buf, 4*C, C);
        add_bias(g_ff, W.fc_b[l], 4*C);
        gelu_inplace(g_ff, 4*C);
        matmul_vec(g_buf, W.mlp_proj_w[l], g_ff, C, 4*C);
        add_bias(g_buf, W.mlp_proj_b[l], C);
        residual_add(g_x, g_buf, C);
    }

    // Final LayerNorm and projection to vocabulary logits.
    layer_norm(g_buf, g_x, W.ln_f_w, W.ln_f_b, C);
    matmul_vec(g_logits, W.lm_head_w, g_buf, cfg.vocab_size, C);
}
|
|
| |
| |
| |
| static void map_weights(float* data) { |
| float* p = data; |
| const int C = cfg.n_embd, L = cfg.n_layer; |
| W.wte=p; p+=(long long)cfg.vocab_size*C; |
| W.wpe=p; p+=(long long)cfg.block_size*C; |
| #define ARR(f) W.f=(float**)malloc(L*sizeof(float*)) |
| ARR(ln1_w); ARR(ln1_b); ARR(c_attn_w); ARR(c_attn_b); |
| ARR(c_proj_w); ARR(c_proj_b); ARR(ln2_w); ARR(ln2_b); |
| ARR(fc_w); ARR(fc_b); ARR(mlp_proj_w); ARR(mlp_proj_b); |
| #undef ARR |
| for (int l = 0; l < L; l++) { |
| W.ln1_w[l]=p; p+=C; W.ln1_b[l]=p; p+=C; |
| W.c_attn_w[l]=p; p+=3LL*C*C; W.c_attn_b[l]=p; p+=3LL*C; |
| W.c_proj_w[l]=p; p+=1LL*C*C; W.c_proj_b[l]=p; p+=C; |
| W.ln2_w[l]=p; p+=C; W.ln2_b[l]=p; p+=C; |
| W.fc_w[l]=p; p+=4LL*C*C; W.fc_b[l]=p; p+=4LL*C; |
| W.mlp_proj_w[l]=p; p+=1LL*C*4*C; W.mlp_proj_b[l]=p; p+=C; |
| } |
| W.ln_f_w=p; p+=C; W.ln_f_b=p; p+=C; W.lm_head_w=p; |
| } |
|
|
| |
| |
| |
| static long long kv_alloc_bytes() { |
| return (long long)cfg.n_layer * cfg.block_size * cfg.n_embd * sizeof(float); |
| } |
|
|
| static void free_session(SessionState& s) { |
| free(s.k_cache); free(s.v_cache); |
| s.k_cache=nullptr; s.v_cache=nullptr; s.pos=0; |
| } |
|
|
| static void evict_oldest() { |
| if (g_sessions.empty()) return; |
| std::string oid; double ot = 1e300; |
| for (auto& kv : g_sessions) |
| if (kv.second.last_used < ot) { ot=kv.second.last_used; oid=kv.first; } |
| free_session(g_sessions[oid]); |
| g_sessions.erase(oid); |
| } |
|
|
| static SessionState& get_or_create(const std::string& id) { |
| auto it = g_sessions.find(id); |
| if (it != g_sessions.end()) { |
| it->second.last_used = get_time_ms(); |
| return it->second; |
| } |
| if ((int)g_sessions.size() >= MAX_SESSIONS) evict_oldest(); |
| SessionState s; |
| long long nb = kv_alloc_bytes(); |
| s.k_cache = (float*)calloc(nb, 1); |
| s.v_cache = (float*)calloc(nb, 1); |
| s.pos = 0; |
| s.last_used = get_time_ms(); |
| g_sessions[id] = s; |
| return g_sessions[id]; |
| } |
|
|
| |
| |
| |
| static int sample_topk(float temperature, int top_k) { |
| for (int v = 0; v < cfg.vocab_size; v++) g_logits[v] /= temperature; |
| std::vector<std::pair<float,int>> pairs(cfg.vocab_size); |
| for (int v = 0; v < cfg.vocab_size; v++) pairs[v]={g_logits[v],v}; |
| std::partial_sort(pairs.begin(), pairs.begin()+top_k, pairs.end(), |
| [](const std::pair<float,int>& a, const std::pair<float,int>& b){ |
| return a.first > b.first; |
| }); |
| float sum=0.f; |
| for (int j=0; j<top_k; j++) { pairs[j].first=expf(pairs[j].first); sum+=pairs[j].first; } |
| for (int j=0; j<top_k; j++) pairs[j].first /= sum; |
| float r=(float)rand()/((float)RAND_MAX+1.f), cum=0.f; |
| int best=pairs[0].second; |
| for (int j=0; j<top_k; j++) { cum+=pairs[j].first; if(r<cum){best=pairs[j].second;break;} } |
| return best; |
| } |
|
|
| |
| |
| |
/* Split `s` on every occurrence of delimiter `d`.
   Empty fields are kept, so N delimiters always yield N+1 fields and an
   empty input yields a single empty field. */
static std::vector<std::string> split(const std::string& s, char d) {
    std::vector<std::string> fields;
    std::string::size_type begin = 0;
    for (;;) {
        const std::string::size_type hit = s.find(d, begin);
        if (hit == std::string::npos) {
            fields.push_back(s.substr(begin));
            return fields;
        }
        fields.push_back(s.substr(begin, hit - begin));
        begin = hit + 1;
    }
}
|
|
/* Parse a comma-separated list of decimal integers, e.g. "15496,11,995".

   Empty fields are skipped. Each field is parsed like atoi(): leading
   whitespace and an optional sign are accepted, and anything after the digits
   is ignored. Unlike atoi(), a field with no digits is dropped instead of
   silently becoming 0. A value outside int range is also dropped, where
   atoi() would have undefined behaviour. The input comes straight from the
   request line, so it is untrusted. */
static std::vector<int> parse_ints(const std::string& s) {
    std::vector<int> out;
    std::string::size_type start = 0;
    while (start <= s.size()) {
        std::string::size_type comma = s.find(',', start);
        if (comma == std::string::npos) comma = s.size();
        if (comma > start) {
            const std::string field = s.substr(start, comma - start);
            const char* begin = field.c_str();
            char* end = nullptr;
            errno = 0;
            const long v = strtol(begin, &end, 10);
            if (end != begin && errno != ERANGE && v >= INT_MIN && v <= INT_MAX)
                out.push_back((int)v);
        }
        start = comma + 1;
    }
    return out;
}
|
|
| |
| |
| |
| static void handle_request(const std::string& line) { |
| auto parts = split(line, '|'); |
| if (parts.size() < 7) { |
| printf("ERROR bad_request_format\n"); fflush(stdout); return; |
| } |
| std::string sess_id = parts[1]; |
| auto new_tokens = parse_ints(parts[2]); |
| int max_new = atoi(parts[3].c_str()); |
| float temp = (float)atof(parts[4].c_str()); |
| int top_k = atoi(parts[5].c_str()); |
| auto stop_list = parse_ints(parts[6]); |
|
|
| if (temp < 0.01f) temp = 0.01f; |
| if (top_k < 1) top_k = 1; |
| if (top_k > cfg.vocab_size) top_k = cfg.vocab_size; |
| if (max_new < 1) max_new = 1; |
|
|
| std::unordered_set<int> stop_ids(stop_list.begin(), stop_list.end()); |
| stop_ids.insert(50256); |
|
|
| SessionState& sess = get_or_create(sess_id); |
|
|
| |
| for (int tok : new_tokens) { |
| if (sess.pos >= cfg.block_size) { |
| printf("ERROR context_window_full\n"); fflush(stdout); return; |
| } |
| forward(tok, sess.pos, sess.k_cache, sess.v_cache); |
| sess.pos++; |
| } |
|
|
| |
| double t0 = get_time_ms(); |
| int gen = 0; |
| for (int i = 0; i < max_new; i++) { |
| if (sess.pos >= cfg.block_size) break; |
| int best = sample_topk(temp, top_k); |
| printf("TOKEN %d %.2f\n", best, get_time_ms()-t0); |
| fflush(stdout); |
| gen++; |
| if (stop_ids.count(best)) break; |
| forward(best, sess.pos, sess.k_cache, sess.v_cache); |
| sess.pos++; |
| } |
|
|
| printf("DONE %d %.2f\n", gen, get_time_ms()-t0); |
| fflush(stdout); |
| } |
|
|
| static void handle_reset(const std::string& line) { |
| auto parts = split(line, '|'); |
| if (parts.size() < 2) { printf("RESET_OK\n"); fflush(stdout); return; } |
| auto it = g_sessions.find(parts[1]); |
| if (it != g_sessions.end()) { |
| free_session(it->second); |
| g_sessions.erase(it); |
| } |
| printf("RESET_OK\n"); fflush(stdout); |
| } |
|
|
| |
| |
| |
| int main() { |
| FILE* f = fopen("model.bin", "rb"); |
| if (!f) { printf("ERROR model.bin_not_found\n"); fflush(stdout); return 1; } |
|
|
| fread(&cfg, sizeof(int), 5, f); |
| fseek(f, 0, SEEK_END); |
| long fsize = ftell(f); |
| fseek(f, 5*(long)sizeof(int), SEEK_SET); |
| long wbytes = fsize - 5*(long)sizeof(int); |
|
|
| g_model_data = (float*)malloc(wbytes); |
| if (!g_model_data) { printf("ERROR oom_loading_model\n"); fflush(stdout); return 1; } |
| fread(g_model_data, 1, wbytes, f); |
| fclose(f); |
|
|
| map_weights(g_model_data); |
|
|
| const int C = cfg.n_embd; |
| g_x = (float*)malloc(C*sizeof(float)); |
| g_buf = (float*)malloc(C*sizeof(float)); |
| g_qkv = (float*)malloc(3*C*sizeof(float)); |
| g_attn = (float*)malloc(cfg.n_head*cfg.block_size*sizeof(float)); |
| g_ff = (float*)malloc(4*C*sizeof(float)); |
| g_logits = (float*)malloc(cfg.vocab_size*sizeof(float)); |
|
|
| srand((unsigned int)time(NULL)); |
| printf("READY\n"); fflush(stdout); |
|
|
| std::string line; |
| while (std::getline(std::cin, line)) { |
| if (!line.empty() && line.back()=='\r') line.pop_back(); |
| if (line.empty()) continue; |
| if (line == "QUIT") break; |
| else if (line.rfind("RESET|",0)==0) handle_reset(line); |
| else if (line.rfind("REQUEST|",0)==0) handle_request(line); |
| else { printf("ERROR unknown_cmd\n"); fflush(stdout); } |
| } |
|
|
| for (auto& kv : g_sessions) free_session(kv.second); |
| free(g_model_data); |
| free(g_x); free(g_buf); free(g_qkv); free(g_attn); free(g_ff); free(g_logits); |
| return 0; |
| } |