// LLaMA.Cpp / inference.cpp
// Last change: commit 1c1ef4a (verified), "Update inference.cpp", by NOT-OMEGA
/*
* ============================================================
* KVInfer — Llama 1B Inference Engine v1.0
* ============================================================
*
* Architecture:
* RMSNorm · RoPE · GQA (n_kv_head != n_head) · SwiGLU MLP
* AVX2 + FMA matmul · OpenMP parallelism · KV-Cache
*
* ── STDIN PROTOCOL ──────────────────────────────────────────
* REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
* RESET|<sess>
* QUIT
*
* ── STDOUT PROTOCOL ─────────────────────────────────────────
* READY
* TOKEN <id> <elapsed_ms>   (ms since generation began; prefill not included)
* DONE <count> <total_ms>   (generation-phase ms; prefill not included)
* RESET_OK
* ERROR <message>
*
* ── COMPILE ─────────────────────────────────────────────────
* g++ -O3 -march=native -fopenmp -ffast-math -std=c++17 \
* -o inference inference.cpp -lm
* ============================================================
*/
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <ctime>
#include <cassert>
#include <algorithm>
#include <string>
#include <vector>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <immintrin.h>
#ifdef _OPENMP
#include <omp.h>
#endif
#ifdef _WIN32
// <windows.h> defines function-like min()/max() macros unless NOMINMAX is set;
// those macros break every later `std::min(...)` / `std::max(...)` in this file.
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
// Milliseconds from the high-resolution performance counter (monotonic).
static double get_ms(){LARGE_INTEGER f,c;QueryPerformanceFrequency(&f);QueryPerformanceCounter(&c);return(double)c.QuadPart/f.QuadPart*1000.0;}
#else
#include <sys/time.h>
#include <time.h>
// Milliseconds from CLOCK_MONOTONIC. Unlike gettimeofday(), this clock never
// jumps when the wall clock is adjusted, so elapsed-time reports and LRU
// timestamps stay consistent. The epoch is arbitrary; only differences matter.
static double get_ms() {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return ts.tv_sec * 1000.0 + ts.tv_nsec / 1.0e6;
}
#endif
// ─────────────────────────────────────────────────────────────────────────
// Config (filled from binary header)
// ─────────────────────────────────────────────────────────────────────────
// Model hyper-parameters, declared in the same order main() reads them from the
// binary header: seven `int` fields followed by one `float`.
struct Config {
    int n_layer;         // transformer blocks
    int n_head;          // query heads per attention layer
    int n_kv_head;       // key/value heads; n_head / n_kv_head query heads share each
    int n_embd;          // residual-stream width
    int n_intermediate;  // SwiGLU hidden width
    int vocab_size;      // rows of the embedding and LM-head tables
    int max_seq_len;     // per-session KV-cache capacity, in tokens
    float rope_theta;    // RoPE base frequency
};
// ─────────────────────────────────────────────────────────────────────────
// Weights (pointers into mmap'd / malloc'd buffer)
// ─────────────────────────────────────────────────────────────────────────
// Non-owning views into the weight blob. Tensor memory belongs to g_data; the
// per-layer members are pointer tables (one entry per layer) set up by map_weights().
struct Weights {
    // Token embedding table: vocab_size rows of n_embd floats.
    float* embed_tokens;

    // ---- per-layer tensors: member[l] points at layer l's tensor ----
    float** rms_att;    // attention RMSNorm gain        [n_embd]
    float** q_proj;     // query projection              [n_head * hs, n_embd]
    float** k_proj;     // key projection                [n_kv_head * hs, n_embd]
    float** v_proj;     // value projection              [n_kv_head * hs, n_embd]
    float** o_proj;     // attention output projection   [n_embd, n_head * hs]
    float** rms_ffn;    // MLP RMSNorm gain              [n_embd]
    float** gate_proj;  // SwiGLU gate projection        [n_intermediate, n_embd]
    float** up_proj;    // SwiGLU up projection          [n_intermediate, n_embd]
    float** down_proj;  // SwiGLU down projection        [n_embd, n_intermediate]

    // ---- output head ----
    float* rms_final;   // final RMSNorm gain            [n_embd]
    float* lm_head;     // vocabulary projection         [vocab_size, n_embd]
};
// Process-wide model state: the header values, the tensor views, and the raw
// weight blob those views point into (allocated and freed by main()).
static Config cfg{};
static Weights W{};
static float* g_data{nullptr};
// ─────────────────────────────────────────────────────────────────────────
// Session (per-session KV cache + position)
// ─────────────────────────────────────────────────────────────────────────
// State of one conversation: its private KV cache and how far it is filled.
struct Session {
    float* k_cache{nullptr};  // keys,   laid out [n_layer][max_seq_len][n_kv_head * hs]
    float* v_cache{nullptr};  // values, same layout as k_cache
    int pos{0};               // number of positions already written to the caches
    double last_use{0.0};     // get_ms() time of the latest access; the LRU key
};
// Live-session limit; creating one more first evicts the least recently used.
static constexpr int MAX_SESSIONS = 8;
static std::unordered_map<std::string, Session> g_sessions;
// ─────────────────────────────────────────────────────────────────────────
// Working buffers (shared across requests, single-threaded per request)
// ─────────────────────────────────────────────────────────────────────────
// Scratch activations reused by every forward() call; allocated once in main().
static float* g_x       = nullptr;  // residual stream                    [n_embd]
static float* g_xb      = nullptr;  // normed input, then attention output [n_embd]
static float* g_q       = nullptr;  // query vector                       [n_embd] (= n_head * hs)
static float* g_k       = nullptr;  // current token's keys               [n_kv_head * hs]
static float* g_v       = nullptr;  // current token's values             [n_kv_head * hs]
static float* g_attn    = nullptr;  // attention scores, one row per head [n_head * max_seq_len]
static float* g_ff_gate = nullptr;  // SwiGLU gate, then its activation   [n_intermediate]
static float* g_ff_up   = nullptr;  // SwiGLU up projection               [n_intermediate]
static float* g_logits  = nullptr;  // next-token logits                  [vocab_size]
// ─────────────────────────────────────────────────────────────────────────
// Math Kernels
// ─────────────────────────────────────────────────────────────────────────
// RMSNorm over N values: out[i] = x[i] * w[i] / sqrt(mean(x^2) + 1e-5).
// All of x is read before anything is written, so in-place use (out == x) is safe.
static void rmsnorm(float* out, const float* x, const float* w, int N) {
    float sum_sq = 0.0f;
    for (const float* p = x; p != x + N; ++p) sum_sq += *p * *p;
    const float inv_rms = 1.0f / sqrtf(sum_sq / N + 1e-5f);
    std::transform(x, x + N, w, out,
                   [inv_rms](float xi, float wi) { return xi * inv_rms * wi; });
}
// Matrix-vector product: out[i] = sum_j mat[i*K + j] * x[j] for i in [0, M).
// `mat` is row-major (M rows of K floats). Rows are split across OpenMP threads
// and every row reads all of `x`, so `out` must not alias `x` or `mat`.
//
// The 8-wide AVX2/FMA path is compiled only when the compiler targets those
// extensions (e.g. -march=native on an AVX2 CPU, or -mavx2 -mfma). Otherwise a
// scalar loop is used, so the engine still builds and runs without them instead
// of failing to compile on the intrinsics.
static void matmul(float* out, const float* mat, const float* x, int M, int K) {
#pragma omp parallel for schedule(static)
    for (int i = 0; i < M; i++) {
        const float* row = mat + (long long)i * K;
        float s = 0.0f;
        int j = 0;
#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
        __m256 acc = _mm256_setzero_ps();
        for (; j <= K - 8; j += 8)
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(row + j),
                                  _mm256_loadu_ps(x + j), acc);
        float lanes[8];
        _mm256_storeu_ps(lanes, acc);
        s = lanes[0] + lanes[1] + lanes[2] + lanes[3]
          + lanes[4] + lanes[5] + lanes[6] + lanes[7];
#endif
        // Remaining K % 8 elements (or the whole row on the scalar path).
        for (; j < K; j++) s += row[j] * x[j];
        out[i] = s;
    }
}
// SwiGLU activation: out[i] = up[i] * silu(gate[i]), where
// silu(g) = g * sigmoid(g) = g / (1 + e^-g).
// Purely element-wise, so `out` may alias `gate` or `up` (forward() writes into gate).
static void swiglu(float* out, const float* gate, const float* up, int N) {
#pragma omp parallel for
    for (int k = 0; k < N; ++k) {
        const float g = gate[k];
        out[k] = up[k] * (g / (1.0f + expf(-g)));
    }
}
// In-place softmax over x[0..n). The maximum is subtracted before
// exponentiating so expf never overflows. Expects n >= 1.
static void softmax(float* x, int n) {
    float* const end = x + n;
    const float peak = *std::max_element(x, end);
    float total = 0.0f;
    for (float* p = x; p != end; ++p) {
        *p = expf(*p - peak);
        total += *p;
    }
    for (float* p = x; p != end; ++p) *p /= total;
}
// ─────────────────────────────────────────────────────────────────────────
// RoPE (Rotary Position Embedding)
// Apply in-place to a query/key vector of length dim (for one head)
// ─────────────────────────────────────────────────────────────────────────
// Rotates consecutive element pairs (x[i], x[i+1]) of one head in place by
// angle pos * theta^(-i/dim), for i = 0, 2, 4, ... < dim.
//
// NOTE(review): this is the interleaved-pair layout. Hugging Face Llama
// checkpoints use the half-split "rotate_half" layout and only match this if
// the exporter permuted the q_proj/k_proj rows. Confirm against the export
// script. Any rope_scaling from the checkpoint config is not applied here.
static void rope(float* x, int pos, int dim, float theta) {
    int i = 0;
    while (i < dim) {
        const float inv_freq = 1.0f / powf(theta, (float)i / dim);
        const float angle = pos * inv_freq;
        const float cos_a = cosf(angle);
        const float sin_a = sinf(angle);
        float* pair = x + i;
        const float even = pair[0];
        const float odd = pair[1];
        pair[0] = even * cos_a - odd * sin_a;
        pair[1] = even * sin_a + odd * cos_a;
        i += 2;
    }
}
// ─────────────────────────────────────────────────────────────────────────
// Forward (single token at position pos)
// ─────────────────────────────────────────────────────────────────────────
// Runs one token through the network at sequence position `pos`.
//
// Effects on file-level state:
//   * writes this token's K and V rows (for every layer) into k_cache / v_cache
//     at row `pos`; cache layout is [n_layer][max_seq_len][n_kv_head * hs];
//   * leaves the next-token logits in g_logits[0 .. vocab_size);
//   * overwrites the scratch buffers g_x, g_xb, g_q, g_k, g_v, g_attn, g_ff_*.
// Attention reads cache rows 0..pos, so rows before `pos` must already hold this
// session's earlier tokens. The caller must guarantee 0 <= token_id < vocab_size
// and 0 <= pos < max_seq_len; neither is checked here.
static void forward(int token_id, int pos, float* k_cache, float* v_cache) {
    const int C = cfg.n_embd;
    const int H = cfg.n_head;
    const int KVH = cfg.n_kv_head;
    const int S = cfg.max_seq_len;
    const int hs = C / H;           // per-head dimension
    const int kv_dim = KVH * hs;    // floats in one K (or V) cache row
    const int GRP = H / KVH;        // query heads sharing one KV head (assumes H % KVH == 0)
    const float scale = 1.0f / sqrtf((float)hs);
    const long long layer_stride = (long long)S * kv_dim;  // cache floats per layer

    // Holds the O / down projection output before it is added to the residual.
    // Heap-backed and grown once, replacing `float tmp[C]` VLAs (a non-standard
    // extension that also put n_embd floats on the stack on every call).
    // forward() is only called from the single-threaded command loop, so a
    // function-local buffer is safe; the OpenMP regions below run inside one call.
    static std::vector<float> branch;
    if ((int)branch.size() < C) branch.resize(C);
    float* const tmp = branch.data();

    // Token embedding -> residual stream
    memcpy(g_x, W.embed_tokens + (long long)token_id * C, C * sizeof(float));

    for (int l = 0; l < cfg.n_layer; l++) {
        // ── Attention ─────────────────────────────────────────────────────
        rmsnorm(g_xb, g_x, W.rms_att[l], C);
        // Q, K, V projections (no bias terms)
        matmul(g_q, W.q_proj[l], g_xb, C, C);           // [H * hs]
        matmul(g_k, W.k_proj[l], g_xb, kv_dim, C);      // [KVH * hs]
        matmul(g_v, W.v_proj[l], g_xb, kv_dim, C);      // [KVH * hs]
        // Rotary position embedding, applied per head to Q and K
        for (int h = 0; h < H; h++) rope(g_q + h * hs, pos, hs, cfg.rope_theta);
        for (int h = 0; h < KVH; h++) rope(g_k + h * hs, pos, hs, cfg.rope_theta);

        // Append this token's K/V to the layer's cache at row `pos`
        float* kc = k_cache + l * layer_stride;
        float* vc = v_cache + l * layer_stride;
        memcpy(kc + (long long)pos * kv_dim, g_k, kv_dim * sizeof(float));
        memcpy(vc + (long long)pos * kv_dim, g_v, kv_dim * sizeof(float));

        // GQA: query head h attends with KV head h / GRP. Each head writes only
        // its own g_attn row and its own hs-wide slice of g_xb, so the heads are
        // independent. g_xb's normed input is no longer needed at this point.
#pragma omp parallel for schedule(static)
        for (int h = 0; h < H; h++) {
            const int kv_h = h / GRP;
            const float* qh = g_q + h * hs;
            float* att = g_attn + (long long)h * S;
            // Scaled dot-product scores against cached keys 0..pos
            for (int t = 0; t <= pos; t++) {
                const float* kh = kc + (long long)t * kv_dim + kv_h * hs;
                float dot = 0.0f;
                for (int d = 0; d < hs; d++) dot += qh[d] * kh[d];
                att[t] = dot * scale;
            }
            softmax(att, pos + 1);
            // Attention-weighted sum of cached values
            float* out_h = g_xb + h * hs;
            memset(out_h, 0, hs * sizeof(float));
            for (int t = 0; t <= pos; t++) {
                const float* vh = vc + (long long)t * kv_dim + kv_h * hs;
                const float a = att[t];
                for (int d = 0; d < hs; d++) out_h[d] += a * vh[d];
            }
        }

        // O projection + residual
        matmul(tmp, W.o_proj[l], g_xb, C, C);
#pragma omp parallel for
        for (int i = 0; i < C; i++) g_x[i] += tmp[i];

        // ── MLP (SwiGLU) ──────────────────────────────────────────────────
        rmsnorm(g_xb, g_x, W.rms_ffn[l], C);
        // Gate and up projections (each matmul is internally parallel)
        matmul(g_ff_gate, W.gate_proj[l], g_xb, cfg.n_intermediate, C);
        matmul(g_ff_up, W.up_proj[l], g_xb, cfg.n_intermediate, C);
        // ff = silu(gate) * up, written back into g_ff_gate
        swiglu(g_ff_gate, g_ff_gate, g_ff_up, cfg.n_intermediate);
        // Down projection + residual
        matmul(tmp, W.down_proj[l], g_ff_gate, C, cfg.n_intermediate);
#pragma omp parallel for
        for (int i = 0; i < C; i++) g_x[i] += tmp[i];
    }

    // Final RMSNorm + LM head
    rmsnorm(g_xb, g_x, W.rms_final, C);
    matmul(g_logits, W.lm_head, g_xb, cfg.vocab_size, C);
}
// ─────────────────────────────────────────────────────────────────────────
// Weight Mapping
// ─────────────────────────────────────────────────────────────────────────
// Points every tensor view in W into the flat weight blob `data`, walking it in
// this order (row-major floats):
//   embed_tokens [vocab_size, C]
//   per layer: rms_att [C], q_proj [C, C], k_proj [kv_dim, C], v_proj [kv_dim, C],
//              o_proj [C, C], rms_ffn [C], gate_proj [F, C], up_proj [F, C],
//              down_proj [C, F]
//   rms_final [C], lm_head [vocab_size, C]
// No bounds checking happens here: `data` must hold at least that many floats.
// The nine per-layer pointer tables are malloc'd here (caller frees them).
static void map_weights(float* data) {
    const int C = cfg.n_embd;
    const int L = cfg.n_layer;
    const int KVH = cfg.n_kv_head;
    const int hs = C / cfg.n_head;
    const int kv_dim = KVH * hs;
    const int F = cfg.n_intermediate;

    // One pointer per layer. A failed allocation is reported on the protocol
    // (same "ERROR oom" that main() uses) instead of crashing on the first store.
    auto layer_table = [L]() -> float** {
        float** table = static_cast<float**>(malloc((size_t)L * sizeof(float*)));
        if (!table && L > 0) { printf("ERROR oom\n"); fflush(stdout); exit(1); }
        return table;
    };
    W.rms_att   = layer_table();
    W.q_proj    = layer_table();
    W.k_proj    = layer_table();
    W.v_proj    = layer_table();
    W.o_proj    = layer_table();
    W.rms_ffn   = layer_table();
    W.gate_proj = layer_table();
    W.up_proj   = layer_table();
    W.down_proj = layer_table();

    float* p = data;
    W.embed_tokens = p; p += (long long)cfg.vocab_size * C;
    for (int l = 0; l < L; l++) {
        W.rms_att[l]   = p; p += C;
        W.q_proj[l]    = p; p += (long long)C * C;
        W.k_proj[l]    = p; p += (long long)kv_dim * C;
        W.v_proj[l]    = p; p += (long long)kv_dim * C;
        W.o_proj[l]    = p; p += (long long)C * C;
        W.rms_ffn[l]   = p; p += C;
        W.gate_proj[l] = p; p += (long long)F * C;
        W.up_proj[l]   = p; p += (long long)F * C;
        W.down_proj[l] = p; p += (long long)C * F;
    }
    W.rms_final = p; p += C;
    W.lm_head = p;
}
// ─────────────────────────────────────────────────────────────────────────
// Session Management (LRU evict)
// ─────────────────────────────────────────────────────────────────────────
// Size in bytes of ONE of a session's two caches (K or V):
// n_layer * max_seq_len rows, each n_kv_head * head_dim floats wide.
static long long kv_bytes() {
    const int head_dim = cfg.n_embd / cfg.n_head;
    const int row_floats = cfg.n_kv_head * head_dim;
    const long long rows = (long long)cfg.n_layer * cfg.max_seq_len;
    return rows * row_floats * (long long)sizeof(float);
}
// Releases a session's K/V caches and rewinds it to position 0. last_use is
// left untouched. Calling it twice is harmless because the pointers are nulled.
static void free_sess(Session& s) {
    free(s.k_cache);
    s.k_cache = nullptr;
    free(s.v_cache);
    s.v_cache = nullptr;
    s.pos = 0;
}
// Frees and removes the least recently used session (smallest last_use).
// Does nothing when no session exists. Ties go to the first one iterated.
// The session is erased through the iterator found by the search; the previous
// `g_sessions[oid]` re-lookup could default-insert an entry if no candidate matched.
static void evict_oldest() {
    if (g_sessions.empty()) return;
    auto victim = std::min_element(
        g_sessions.begin(), g_sessions.end(),
        [](const auto& a, const auto& b) { return a.second.last_use < b.second.last_use; });
    free_sess(victim->second);
    g_sessions.erase(victim);
}
// Returns the session named `id`, refreshing its last_use. If it does not exist,
// creates it with zeroed K/V caches, evicting the least recently used session
// first when MAX_SESSIONS are already live.
// If the cache allocation fails, more LRU sessions are evicted and the
// allocation retried. When nothing is left to evict, "ERROR oom" is printed and
// the process exits (a null cache would otherwise crash inside forward()).
// The returned reference stays valid until the session is reset or evicted.
static Session& get_or_create(const std::string& id) {
    auto it = g_sessions.find(id);
    if (it != g_sessions.end()) {
        it->second.last_use = get_ms();
        return it->second;
    }
    if ((int)g_sessions.size() >= MAX_SESSIONS) evict_oldest();

    Session s;
    const size_t nb = (size_t)kv_bytes();
    for (;;) {
        s.k_cache = (float*)calloc(nb, 1);
        s.v_cache = (float*)calloc(nb, 1);
        if (s.k_cache && s.v_cache) break;
        free(s.k_cache);
        free(s.v_cache);
        s.k_cache = nullptr;
        s.v_cache = nullptr;
        if (g_sessions.empty()) {
            printf("ERROR oom\n"); fflush(stdout);
            exit(1);
        }
        evict_oldest();
    }
    s.pos = 0;
    s.last_use = get_ms();
    return g_sessions.emplace(id, s).first->second;
}
// ─────────────────────────────────────────────────────────────────────────
// Sampler (Top-K)
// ─────────────────────────────────────────────────────────────────────────
// Samples the next token id from g_logits (read-only here) using temperature
// scaling followed by top-k filtering.
//   temperature: divisor applied to every logit; the caller keeps it >= 0.01.
//   top_k:       candidates kept; clamped here to [1, vocab_size].
// Returns a token id in [0, vocab_size). Uses rand(), so it is deterministic
// for a given srand() seed.
static int sample_topk(float temperature, int top_k) {
    const int V = cfg.vocab_size;
    const int K = std::max(1, std::min(top_k, V));

    // Reused across calls to avoid a vocab-sized allocation per generated token.
    static std::vector<std::pair<float, int>> cand;
    cand.resize(V);
    for (int v = 0; v < V; v++) cand[v] = {g_logits[v] / temperature, v};
    std::partial_sort(cand.begin(), cand.begin() + K, cand.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });

    // Subtract the largest scaled logit before exponentiating. With temperature
    // as low as 0.01, logits/temperature routinely exceed expf's range (~88);
    // without the shift the top weight became inf and the probabilities NaN.
    const float peak = cand[0].first;
    float sum = 0.0f;
    for (int j = 0; j < K; j++) {
        cand[j].first = expf(cand[j].first - peak);
        sum += cand[j].first;
    }

    // Draw from the unnormalised weights by scaling the uniform draw by their sum.
    const float r = (float)rand() / ((float)RAND_MAX + 1.0f) * sum;
    float cum = 0.0f;
    int chosen = cand[0].second;  // fallback if rounding leaves r >= cum
    for (int j = 0; j < K; j++) {
        cum += cand[j].first;
        if (r < cum) { chosen = cand[j].second; break; }
    }
    return chosen;
}
// ─────────────────────────────────────────────────────────────────────────
// Helpers
// ─────────────────────────────────────────────────────────────────────────
// Splits `s` on every occurrence of `d`, keeping empty fields: an input with n
// delimiters always yields n + 1 pieces ("" -> {""}, "a|" -> {"a", ""}).
static std::vector<std::string> split(const std::string& s, char d) {
    std::vector<std::string> pieces;
    std::string::size_type begin = 0;
    for (;;) {
        const std::string::size_type end = s.find(d, begin);
        if (end == std::string::npos) {
            pieces.push_back(s.substr(begin));
            return pieces;
        }
        pieces.emplace_back(s, begin, end - begin);
        begin = end + 1;
    }
}
// Parses a comma-separated list of integers. Empty fields are skipped; every
// other field is converted with atoi(), so leading whitespace is allowed,
// parsing stops at the first non-digit, and unparsable text yields 0.
static std::vector<int> parse_ints(const std::string& s) {
    std::vector<int> values;
    std::string::size_type start = 0;
    while (start <= s.size()) {
        std::string::size_type comma = s.find(',', start);
        if (comma == std::string::npos) comma = s.size();
        if (comma > start) values.push_back(atoi(s.substr(start, comma - start).c_str()));
        start = comma + 1;
    }
    return values;
}
// ─────────────────────────────────────────────────────────────────────────
// Command Handlers
// ─────────────────────────────────────────────────────────────────────────
// REQUEST|<sess>|<new_tokens_csv>|<max_new>|<temp>|<top_k>|<stop_csv>
//
// Appends the new tokens to the session's context (prefill), then samples up to
// max_new tokens, streaming "TOKEN <id> <ms>" lines and ending with
// "DONE <count> <ms>". Times are measured from the end of prefill.
// Errors (request rejected, session not modified):
//   bad_format   - fewer than 7 '|'-separated fields
//   bad_token    - a prompt token id outside [0, vocab_size)
//   empty_prompt - no tokens for a session that has no context yet
//   context_full - the prompt does not fit in the remaining cache rows
// Generation stops early on a stop id (the request's list plus the Llama 3 ids
// 128009 / 128001) or when the cache is full.
//
// NOTE(review): a sampled stop token is reported but never run through
// forward(), so it is not in the KV cache. Also, g_logits is shared by all
// sessions, so an empty-prompt continuation samples from whichever forward()
// ran last. Confirm the client re-sends context accordingly.
static void handle_request(const std::string& line) {
    const auto parts = split(line, '|');
    if (parts.size() < 7) { printf("ERROR bad_format\n"); fflush(stdout); return; }
    const std::string& sess_id = parts[1];
    const auto new_toks = parse_ints(parts[2]);
    int max_new = atoi(parts[3].c_str());
    float temp = (float)atof(parts[4].c_str());
    int top_k = atoi(parts[5].c_str());
    const auto stop_lst = parse_ints(parts[6]);

    // Written as !(temp >= 0.01f) so NaN is clamped too (std::max passes it through).
    if (!(temp >= 0.01f)) temp = 0.01f;
    top_k = std::clamp(top_k, 1, cfg.vocab_size);
    max_new = std::max(max_new, 1);

    // Token ids index the embedding table directly; reject anything out of range
    // before it can reach forward() and read outside the weights.
    for (int tok : new_toks) {
        if (tok < 0 || tok >= cfg.vocab_size) {
            printf("ERROR bad_token\n"); fflush(stdout); return;
        }
    }

    std::unordered_set<int> stop_ids(stop_lst.begin(), stop_lst.end());
    stop_ids.insert(128009); // <|eot_id|> Llama 3 EOS
    stop_ids.insert(128001); // <|end_of_text|>

    Session& sess = get_or_create(sess_id);

    // A session with no context has never produced logits of its own;
    // g_logits would be uninitialised or belong to another session.
    if (new_toks.empty() && sess.pos == 0) {
        printf("ERROR empty_prompt\n"); fflush(stdout); return;
    }
    // Check capacity before touching the cache so an oversized prompt does not
    // leave the session half-prefilled.
    if ((long long)sess.pos + (long long)new_toks.size() > (long long)cfg.max_seq_len) {
        printf("ERROR context_full\n"); fflush(stdout); return;
    }

    // Prefill
    for (int tok : new_toks) {
        forward(tok, sess.pos, sess.k_cache, sess.v_cache);
        sess.pos++;
    }

    // Autoregressive generation
    const double t0 = get_ms();
    int gen = 0;
    for (int i = 0; i < max_new; i++) {
        if (sess.pos >= cfg.max_seq_len) break;
        const int next = sample_topk(temp, top_k);
        printf("TOKEN %d %.2f\n", next, get_ms() - t0);
        fflush(stdout);
        gen++;
        if (stop_ids.count(next)) break;
        forward(next, sess.pos, sess.k_cache, sess.v_cache);
        sess.pos++;
    }
    printf("DONE %d %.2f\n", gen, get_ms() - t0);
    fflush(stdout);
}
// RESET|<sess>: frees and forgets the named session if it exists.
// Always answers RESET_OK, including for unknown session ids.
static void handle_reset(const std::string& line) {
    const auto fields = split(line, '|');
    if (fields.size() > 1) {
        const auto found = g_sessions.find(fields[1]);
        if (found != g_sessions.end()) {
            free_sess(found->second);
            g_sessions.erase(found);
        }
    }
    printf("RESET_OK\n");
    fflush(stdout);
}
// ─────────────────────────────────────────────────────────────────────────
// main
// ─────────────────────────────────────────────────────────────────────────
int main() {
FILE* f = fopen("model_llama.bin", "rb");
if (!f) { printf("ERROR model_llama.bin not found\n"); fflush(stdout); return 1; }
// Read header
int hdr[7]; float theta;
fread(hdr, sizeof(int), 7, f);
fread(&theta, sizeof(float), 1, f);
cfg = {hdr[0], hdr[1], hdr[2], hdr[3], hdr[4], hdr[5], hdr[6], theta};
printf("[engine] Layers=%d Heads=%d KVHeads=%d Embd=%d Inter=%d Vocab=%d Seq=%d Theta=%.0f\n",
cfg.n_layer, cfg.n_head, cfg.n_kv_head, cfg.n_embd,
cfg.n_intermediate, cfg.vocab_size, cfg.max_seq_len, cfg.rope_theta);
fflush(stdout);
// Load weights
fseek(f, 0, SEEK_END); long fsize = ftell(f);
long woff = 7 * sizeof(int) + sizeof(float);
fseek(f, woff, SEEK_SET);
long wbytes = fsize - woff;
g_data = (float*)malloc(wbytes);
if (!g_data) { printf("ERROR oom\n"); fflush(stdout); return 1; }
fread(g_data, 1, wbytes, f);
fclose(f);
map_weights(g_data);
// Working buffers
const int C = cfg.n_embd;
const int F = cfg.n_intermediate;
const int S = cfg.max_seq_len;
const int H = cfg.n_head;
g_x = (float*)malloc(C * sizeof(float));
g_xb = (float*)malloc(C * sizeof(float));
g_q = (float*)malloc(C * sizeof(float)); // H * hs = C
g_k = (float*)malloc(cfg.n_kv_head * (C/H) * sizeof(float));
g_v = (float*)malloc(cfg.n_kv_head * (C/H) * sizeof(float));
g_attn = (float*)malloc((long long)H * S * sizeof(float));
g_ff_gate = (float*)malloc(F * sizeof(float));
g_ff_up = (float*)malloc(F * sizeof(float));
g_logits = (float*)malloc((long long)cfg.vocab_size * sizeof(float));
srand((unsigned)time(NULL));
printf("READY\n"); fflush(stdout);
std::string line;
while (std::getline(std::cin, line)) {
if (!line.empty() && line.back() == '\r') line.pop_back();
if (line.empty()) continue;
if (line == "QUIT") break;
else if (line.rfind("RESET|", 0) == 0) handle_reset(line);
else if (line.rfind("REQUEST|", 0) == 0) handle_request(line);
else { printf("ERROR unknown_cmd\n"); fflush(stdout); }
}
for (auto& kv : g_sessions) free_sess(kv.second);
free(g_data);
free(g_x); free(g_xb); free(g_q); free(g_k); free(g_v);
free(g_attn); free(g_ff_gate); free(g_ff_up); free(g_logits);
return 0;
}