/*
 * ============================================================
 *  KVInfer — Llama 1B Inference Engine v1.0
 * ============================================================
 *
 *  Architecture:
 *    RMSNorm · RoPE · GQA (n_kv_head != n_head) · SwiGLU MLP
 *    AVX2 + FMA matmul (scalar fallback) · OpenMP parallelism · KV-Cache
 *
 *  ── STDIN PROTOCOL ──────────────────────────────────────────
 *    REQUEST|<session_id>|<tok,tok,...>|<max_new>|<temperature>|<top_k>|<stop,stop,...>
 *    RESET|<session_id>
 *    QUIT
 *
 *  ── STDOUT PROTOCOL ─────────────────────────────────────────
 *    READY
 *    TOKEN <token_id> <elapsed_ms>
 *    DONE <n_generated> <elapsed_ms>
 *    RESET_OK
 *    ERROR <message>
 *
 *  ── COMPILE ─────────────────────────────────────────────────
 *    g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 \
 *        -o inference inference.cpp -lm
 * ============================================================
 */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <algorithm>
#include <chrono>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// The SIMD kernel is only compiled when the target really has AVX2 and FMA
// (e.g. -march=native on such a CPU); otherwise matmul() uses a scalar loop.
#if defined(__AVX2__) && defined(__FMA__)
#include <immintrin.h>
#define KVINFER_HAVE_AVX2_FMA 1
#endif

#ifdef _OPENMP
#include <omp.h>
#endif

// Milliseconds from a monotonic clock. Only differences between two readings
// are meaningful (token timing, LRU ordering); a monotonic source means a
// wall-clock adjustment can never reorder sessions or yield negative timings.
static double get_ms() {
    using clock = std::chrono::steady_clock;
    return std::chrono::duration<double, std::milli>(clock::now().time_since_epoch()).count();
}

// ─────────────────────────────────────────────────────────────────────────
//  Config  (filled from binary header)
// ─────────────────────────────────────────────────────────────────────────
// Field order matters: main() aggregate-initialises this struct from the
// seven header ints followed by the float theta.
struct Config {
    int   n_layer, n_head, n_kv_head, n_embd, n_intermediate, vocab_size, max_seq_len;
    float rope_theta;
};

// ─────────────────────────────────────────────────────────────────────────
//  Weights  (pointers into the single malloc'd weight buffer)
// ─────────────────────────────────────────────────────────────────────────
// hs below denotes the head size, n_embd / n_head.
struct Weights {
    float*  embed_tokens;   // [vocab_size, n_embd]
    // Per-layer pointer tables (n_layer entries each)
    float** rms_att;        // [n_embd]
    float** q_proj;         // [n_head * hs, n_embd]
    float** k_proj;         // [n_kv_head * hs, n_embd]
    float** v_proj;         // [n_kv_head * hs, n_embd]
    float** o_proj;         // [n_embd, n_head * hs]
    float** rms_ffn;        // [n_embd]
    float** gate_proj;      // [n_intermediate, n_embd]
    float** up_proj;        // [n_intermediate, n_embd]
    float** down_proj;      // [n_embd, n_intermediate]
    float*  rms_final;      // [n_embd]
    float*  lm_head;        // [vocab_size, n_embd]
};

static Config  cfg;
static Weights W;
static float*  g_data = nullptr;

// ─────────────────────────────────────────────────────────────────────────
//  Session  (per-session KV cache + position)
// ─────────────────────────────────────────────────────────────────────────
struct Session {
    float* k_cache  = nullptr;   // [n_layer, max_seq_len, n_kv_head * hs]
    float* v_cache  = nullptr;   // same layout as k_cache
    int    pos      = 0;         // number of tokens already written to the cache
    double last_use = 0.0;       // get_ms() timestamp of the last lookup (LRU key)
};

static const int MAX_SESSIONS = 8;
static std::unordered_map<std::string, Session> g_sessions;

// ─────────────────────────────────────────────────────────────────────────
//  Working buffers  (shared across requests; one request runs at a time)
// ─────────────────────────────────────────────────────────────────────────
static float *g_x, *g_xb, *g_q, *g_k, *g_v, *g_attn, *g_ff_gate, *g_ff_up, *g_logits;

// ─────────────────────────────────────────────────────────────────────────
//  Math Kernels
// ─────────────────────────────────────────────────────────────────────────

// RMSNorm: out[i] = x[i] / rms(x) * w[i], with a 1e-5 epsilon inside the root.
// out may alias x.
static void rmsnorm(float* out, const float* x, const float* w, int N) {
    float ss = 0.0f;
    for (int i = 0; i < N; i++) ss += x[i] * x[i];
    const float inv_rms = 1.0f / sqrtf(ss / N + 1e-5f);
    for (int i = 0; i < N; i++) out[i] = x[i] * inv_rms * w[i];
}

// Matrix-vector multiply: out[M] = mat[M,K] * x[K]  (mat is row-major).
// Rows are distributed over OpenMP threads. With AVX2+FMA available each row
// is reduced 8 lanes at a time and the remaining K % 8 elements are handled
// by the scalar tail; without it the scalar loop does the whole row.
static void matmul(float* out, const float* mat, const float* x, int M, int K) {
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < M; i++) {
        const float* row = mat + (long long)i * K;
        int   j = 0;
        float s = 0.0f;
#ifdef KVINFER_HAVE_AVX2_FMA
        __m256 acc = _mm256_setzero_ps();
        for (; j <= K - 8; j += 8)
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(row + j), _mm256_loadu_ps(x + j), acc);
        float lanes[8];
        _mm256_storeu_ps(lanes, acc);
        s = lanes[0] + lanes[1] + lanes[2] + lanes[3] + lanes[4] + lanes[5] + lanes[6] + lanes[7];
#endif
        for (; j < K; j++) s += row[j] * x[j];
        out[i] = s;
    }
}

// SwiGLU: out[i] = silu(gate[i]) * up[i], silu(x) = x / (1 + exp(-x)).
// Element-wise, so out may alias gate (forward() relies on that).
static void swiglu(float* out, const float* gate, const float* up, int N) {
    #pragma omp parallel for
    for (int i = 0; i < N; i++) {
        const float g = gate[i];
        out[i] = (g / (1.0f + expf(-g))) * up[i];
    }
}

// Softmax in place over the first n elements (n >= 1). The maximum is
// subtracted before exponentiating so large scores cannot overflow.
static void softmax(float* x, int n) {
    float mx = x[0];
    for (int i = 1; i < n; i++) if (x[i] > mx) mx = x[i];
    float sum = 0.0f;
    for (int i = 0; i < n; i++) { x[i] = expf(x[i] - mx); sum += x[i]; }
    for (int i = 0; i < n; i++) x[i] /= sum;
}

// ─────────────────────────────────────────────────────────────────────────
//  RoPE  (Rotary Position Embedding)
//  Rotates adjacent pairs (x[i], x[i+1]) of one head vector in place by the
//  angle pos * theta^(-i/dim). dim must be even.
//  NOTE(review): this is the interleaved pairing; confirm the weight export
//  uses the same convention.
// ─────────────────────────────────────────────────────────────────────────
static void rope(float* x, int pos, int dim, float theta) {
    for (int i = 0; i < dim; i += 2) {
        const float freq  = 1.0f / powf(theta, (float)i / dim);
        const float angle = pos * freq;
        const float c = cosf(angle), s = sinf(angle);
        const float x0 = x[i], x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}

// ─────────────────────────────────────────────────────────────────────────
//  Forward  (single token at position pos)
//
//  Runs every layer for one token, writes this position's K/V into the given
//  caches and leaves the next-token logits in g_logits.
//
//  Preconditions (not checked here, callers must guarantee them):
//    0 <= token_id < cfg.vocab_size, 0 <= pos < cfg.max_seq_len,
//    k_cache / v_cache are non-null buffers of kv_bytes() bytes.
// ─────────────────────────────────────────────────────────────────────────
static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
    const int   C      = cfg.n_embd;
    const int   H      = cfg.n_head;
    const int   KVH    = cfg.n_kv_head;
    const int   F      = cfg.n_intermediate;
    const int   S      = cfg.max_seq_len;
    const int   hs     = C / H;        // head size
    const int   kv_dim = KVH * hs;     // width of one cached K (or V) row
    const int   GRP    = H / KVH;      // query heads sharing one KV head
    const float scale  = 1.0f / sqrtf((float)hs);

    // Scratch for the projection output that is added back to the residual.
    // Kept in a reused heap buffer rather than a variable-length stack array,
    // which is not standard C++ and grows the stack with the model width.
    static std::vector<float> scratch;
    if ((int)scratch.size() < C) scratch.resize(C);
    float* const delta = scratch.data();

    // Token embedding
    memcpy(g_x, W.embed_tokens + (long long)token_id * C, C * sizeof(float));

    for (int l = 0; l < cfg.n_layer; l++) {
        // ── Attention ─────────────────────────────────────────────────────
        rmsnorm(g_xb, g_x, W.rms_att[l], C);

        // Q, K, V projections (no bias)
        matmul(g_q, W.q_proj[l], g_xb, C,      C);   // [H * hs]
        matmul(g_k, W.k_proj[l], g_xb, kv_dim, C);   // [KVH * hs]
        matmul(g_v, W.v_proj[l], g_xb, kv_dim, C);   // [KVH * hs]

        // RoPE on Q and K, one head at a time
        for (int h = 0; h < H;   h++) rope(g_q + h * hs, pos, hs, cfg.rope_theta);
        for (int h = 0; h < KVH; h++) rope(g_k + h * hs, pos, hs, cfg.rope_theta);

        // Append this position's K and V to the layer's slice of the cache
        float* kc = k_cache + (long long)l * S * kv_dim;
        float* vc = v_cache + (long long)l * S * kv_dim;
        memcpy(kc + (long long)pos * kv_dim, g_k, kv_dim * sizeof(float));
        memcpy(vc + (long long)pos * kv_dim, g_v, kv_dim * sizeof(float));

        // GQA attention: every query head attends through its group's KV head.
        // Heads write disjoint rows of g_attn and disjoint slices of g_xb, so
        // they can run in parallel. g_xb is free to overwrite here because the
        // Q/K/V projections above were its last readers.
        #pragma omp parallel for schedule(static)
        for (int h = 0; h < H; h++) {
            const int    kv_h = h / GRP;
            const float* qh   = g_q + h * hs;
            float*       att  = g_attn + (long long)h * S;

            for (int t = 0; t <= pos; t++) {
                const float* kh = kc + (long long)t * kv_dim + kv_h * hs;
                float dot = 0.0f;
                for (int d = 0; d < hs; d++) dot += qh[d] * kh[d];
                att[t] = dot * scale;
            }
            softmax(att, pos + 1);

            float* out_h = g_xb + h * hs;
            memset(out_h, 0, hs * sizeof(float));
            for (int t = 0; t <= pos; t++) {
                const float* vh = vc + (long long)t * kv_dim + kv_h * hs;
                const float  a  = att[t];
                for (int d = 0; d < hs; d++) out_h[d] += a * vh[d];
            }
        }

        // O projection + residual
        matmul(delta, W.o_proj[l], g_xb, C, C);
        #pragma omp parallel for
        for (int i = 0; i < C; i++) g_x[i] += delta[i];

        // ── MLP (SwiGLU) ──────────────────────────────────────────────────
        rmsnorm(g_xb, g_x, W.rms_ffn[l], C);
        matmul(g_ff_gate, W.gate_proj[l], g_xb, F, C);
        matmul(g_ff_up,   W.up_proj[l],   g_xb, F, C);
        swiglu(g_ff_gate, g_ff_gate, g_ff_up, F);   // in place: gate <- silu(gate) * up

        // Down projection + residual
        matmul(delta, W.down_proj[l], g_ff_gate, C, F);
        #pragma omp parallel for
        for (int i = 0; i < C; i++) g_x[i] += delta[i];
    }

    // Final RMSNorm + LM head
    rmsnorm(g_xb, g_x, W.rms_final, C);
    matmul(g_logits, W.lm_head, g_xb, cfg.vocab_size, C);
}

// ─────────────────────────────────────────────────────────────────────────
//  Weight Mapping
//  Carves the flat float buffer into named tensors in file order:
//    embed_tokens, then per layer { rms_att, q, k, v, o, rms_ffn, gate, up,
//    down }, then rms_final, then lm_head.
//  The per-layer pointer tables are malloc'd here; cfg must already be set.
// ─────────────────────────────────────────────────────────────────────────
static void map_weights(float* data) {
    const int C      = cfg.n_embd;
    const int L      = cfg.n_layer;
    const int hs     = C / cfg.n_head;
    const int kv_dim = cfg.n_kv_head * hs;
    const int F      = cfg.n_intermediate;

    float* p = data;
    // Hands out the next `count` floats of the buffer.
    auto take = [&p](long long count) { float* at = p; p += count; return at; };
    auto new_table = [L]() { return (float**)malloc((size_t)L * sizeof(float*)); };

    W.embed_tokens = take((long long)cfg.vocab_size * C);

    W.rms_att   = new_table();  W.q_proj  = new_table();  W.k_proj    = new_table();
    W.v_proj    = new_table();  W.o_proj  = new_table();  W.rms_ffn   = new_table();
    W.gate_proj = new_table();  W.up_proj = new_table();  W.down_proj = new_table();

    for (int l = 0; l < L; l++) {
        W.rms_att[l]   = take(C);
        W.q_proj[l]    = take((long long)C * C);
        W.k_proj[l]    = take((long long)kv_dim * C);
        W.v_proj[l]    = take((long long)kv_dim * C);
        W.o_proj[l]    = take((long long)C * C);
        W.rms_ffn[l]   = take(C);
        W.gate_proj[l] = take((long long)F * C);
        W.up_proj[l]   = take((long long)F * C);
        W.down_proj[l] = take((long long)C * F);
    }
    W.rms_final = take(C);
    W.lm_head   = p;
}

// ─────────────────────────────────────────────────────────────────────────
//  Session Management  (LRU evict)
// ─────────────────────────────────────────────────────────────────────────

// Size in bytes of ONE cache (K or V) for one session.
static long long kv_bytes() {
    const long long kv_dim = (long long)cfg.n_kv_head * (cfg.n_embd / cfg.n_head);
    return (long long)cfg.n_layer * cfg.max_seq_len * kv_dim * (long long)sizeof(float);
}

// Releases a session's caches and rewinds it; safe to call on null caches.
static void free_sess(Session& s) {
    free(s.k_cache);
    free(s.v_cache);
    s.k_cache = nullptr;
    s.v_cache = nullptr;
    s.pos     = 0;
}

// Drops the session with the smallest last_use timestamp, if any exist.
static void evict_oldest() {
    if (g_sessions.empty()) return;
    auto victim = g_sessions.begin();
    for (auto it = g_sessions.begin(); it != g_sessions.end(); ++it)
        if (it->second.last_use < victim->second.last_use) victim = it;
    free_sess(victim->second);
    g_sessions.erase(victim);
}

// Returns the session registered under id, creating it (and evicting the
// least recently used one once MAX_SESSIONS is reached) when absent.
// If the cache allocation fails the session is still registered, but BOTH
// cache pointers are null — callers must check before calling forward().
static Session& get_or_create(const std::string& id) {
    auto found = g_sessions.find(id);
    if (found != g_sessions.end()) {
        found->second.last_use = get_ms();
        return found->second;
    }
    if ((int)g_sessions.size() >= MAX_SESSIONS) evict_oldest();

    Session fresh;
    const size_t nb = (size_t)kv_bytes();
    fresh.k_cache = (float*)calloc(nb, 1);
    fresh.v_cache = (float*)calloc(nb, 1);
    // Never keep a half-allocated session: release the survivor so the state
    // is all-or-nothing and nothing leaks.
    if (!fresh.k_cache || !fresh.v_cache) free_sess(fresh);
    fresh.last_use = get_ms();
    return g_sessions.emplace(id, fresh).first->second;
}

// ─────────────────────────────────────────────────────────────────────────
//  Sampler  (Top-K)
//
//  Samples a token id from the top_k largest entries of g_logits after
//  temperature scaling. g_logits is left untouched. top_k is clamped to
//  [1, vocab_size]; temperature <= 0 or top_k == 1 selects the arg-max.
// ─────────────────────────────────────────────────────────────────────────
static int sample_topk(float temperature, int top_k) {
    const int V = cfg.vocab_size;
    if (V <= 0) return 0;
    const int K = std::max(1, std::min(top_k, V));

    if (temperature <= 0.0f || K == 1)
        return (int)(std::max_element(g_logits, g_logits + V) - g_logits);

    // Reused between calls so each sampled token does not reallocate a
    // vocab-sized array.
    static std::vector<std::pair<float, int>> cand;
    cand.resize(V);
    for (int v = 0; v < V; v++) cand[v] = {g_logits[v], v};
    // Ordering by the raw logit equals ordering by logit / temperature
    // because temperature is positive here.
    std::partial_sort(cand.begin(), cand.begin() + K, cand.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });

    // Softmax over the K survivors. Subtracting the largest logit keeps every
    // exponent <= 0, so a small temperature cannot overflow expf to inf/NaN.
    const float top = cand[0].first;
    double total = 0.0;
    for (int j = 0; j < K; j++) {
        cand[j].first = expf((cand[j].first - top) / temperature);
        total += cand[j].first;
    }

    // Uniform draw in [0, 1), computed in double so it cannot round up to 1.
    const double r = (double)rand() / ((double)RAND_MAX + 1.0);
    double cum = 0.0;
    for (int j = 0; j < K; j++) {
        cum += cand[j].first / total;
        if (r < cum) return cand[j].second;
    }
    return cand[0].second;   // rounding shortfall: fall back to the most likely token
}

// ─────────────────────────────────────────────────────────────────────────
//  Helpers
// ─────────────────────────────────────────────────────────────────────────

// Splits s on every occurrence of d. Empty fields are kept, so the result
// always has (number of delimiters + 1) elements.
static std::vector<std::string> split(const std::string& s, char d) {
    std::vector<std::string> out;
    std::string cur;
    for (char c : s) {
        if (c == d) { out.push_back(cur); cur.clear(); }
        else        cur += c;
    }
    out.push_back(cur);
    return out;
}

// Parses a comma-separated list of integers, skipping empty fields.
// Conversion is atoi(), so a non-numeric field yields 0.
static std::vector<int> parse_ints(const std::string& s) {
    std::vector<int> out;
    for (const auto& field : split(s, ','))
        if (!field.empty()) out.push_back(atoi(field.c_str()));
    return out;
}
───────────────────────────────────────────────────────────────────────── // Command Handlers // ───────────────────────────────────────────────────────────────────────── static void handle_request(const std::string& line) { auto parts = split(line, '|'); if (parts.size() < 7) { printf("ERROR bad_format\n"); fflush(stdout); return; } std::string sess_id = parts[1]; auto new_toks = parse_ints(parts[2]); int max_new = atoi(parts[3].c_str()); float temp = (float)atof(parts[4].c_str()); int top_k = atoi(parts[5].c_str()); auto stop_lst = parse_ints(parts[6]); temp = std::max(temp, 0.01f); top_k = std::clamp(top_k, 1, cfg.vocab_size); max_new = std::max(max_new, 1); std::unordered_set stop_ids(stop_lst.begin(), stop_lst.end()); stop_ids.insert(128009); // <|eot_id|> Llama 3 EOS stop_ids.insert(128001); // <|end_of_text|> Session& sess = get_or_create(sess_id); // Prefill for (int tok : new_toks) { if (sess.pos >= cfg.max_seq_len) { printf("ERROR context_full\n"); fflush(stdout); return; } forward(tok, sess.pos, sess.k_cache, sess.v_cache); sess.pos++; } // Autoregressive generation double t0 = get_ms(); int gen = 0; for (int i = 0; i < max_new; i++) { if (sess.pos >= cfg.max_seq_len) break; int next = sample_topk(temp, top_k); printf("TOKEN %d %.2f\n", next, get_ms() - t0); fflush(stdout); gen++; if (stop_ids.count(next)) break; forward(next, sess.pos, sess.k_cache, sess.v_cache); sess.pos++; } printf("DONE %d %.2f\n", gen, get_ms() - t0); fflush(stdout); } static void handle_reset(const std::string& line) { auto parts = split(line, '|'); if (parts.size() >= 2) { auto it = g_sessions.find(parts[1]); if (it != g_sessions.end()) { free_sess(it->second); g_sessions.erase(it); } } printf("RESET_OK\n"); fflush(stdout); } // ───────────────────────────────────────────────────────────────────────── // main // ───────────────────────────────────────────────────────────────────────── int main() { FILE* f = fopen("model_llama.bin", "rb"); if (!f) { printf("ERROR model_llama.bin 
not found\n"); fflush(stdout); return 1; } // Read header int hdr[7]; float theta; fread(hdr, sizeof(int), 7, f); fread(&theta, sizeof(float), 1, f); cfg = {hdr[0], hdr[1], hdr[2], hdr[3], hdr[4], hdr[5], hdr[6], theta}; printf("[engine] Layers=%d Heads=%d KVHeads=%d Embd=%d Inter=%d Vocab=%d Seq=%d Theta=%.0f\n", cfg.n_layer, cfg.n_head, cfg.n_kv_head, cfg.n_embd, cfg.n_intermediate, cfg.vocab_size, cfg.max_seq_len, cfg.rope_theta); fflush(stdout); // Load weights fseek(f, 0, SEEK_END); long fsize = ftell(f); long woff = 7 * sizeof(int) + sizeof(float); fseek(f, woff, SEEK_SET); long wbytes = fsize - woff; g_data = (float*)malloc(wbytes); if (!g_data) { printf("ERROR oom\n"); fflush(stdout); return 1; } fread(g_data, 1, wbytes, f); fclose(f); map_weights(g_data); // Working buffers const int C = cfg.n_embd; const int F = cfg.n_intermediate; const int S = cfg.max_seq_len; const int H = cfg.n_head; g_x = (float*)malloc(C * sizeof(float)); g_xb = (float*)malloc(C * sizeof(float)); g_q = (float*)malloc(C * sizeof(float)); // H * hs = C g_k = (float*)malloc(cfg.n_kv_head * (C/H) * sizeof(float)); g_v = (float*)malloc(cfg.n_kv_head * (C/H) * sizeof(float)); g_attn = (float*)malloc((long long)H * S * sizeof(float)); g_ff_gate = (float*)malloc(F * sizeof(float)); g_ff_up = (float*)malloc(F * sizeof(float)); g_logits = (float*)malloc((long long)cfg.vocab_size * sizeof(float)); srand((unsigned)time(NULL)); printf("READY\n"); fflush(stdout); std::string line; while (std::getline(std::cin, line)) { if (!line.empty() && line.back() == '\r') line.pop_back(); if (line.empty()) continue; if (line == "QUIT") break; else if (line.rfind("RESET|", 0) == 0) handle_reset(line); else if (line.rfind("REQUEST|", 0) == 0) handle_request(line); else { printf("ERROR unknown_cmd\n"); fflush(stdout); } } for (auto& kv : g_sessions) free_sess(kv.second); free(g_data); free(g_x); free(g_xb); free(g_q); free(g_k); free(g_v); free(g_attn); free(g_ff_gate); free(g_ff_up); free(g_logits); return 0; }