/* Step 1: KV-cached inference.
 *
 * Build: gcc -O3 -march=native -o infer_kv infer_kv.c
 * Run:   ./infer_kv <model.bin> "<prompt>" <n_new>
 *
 * Adds a per-layer K/V cache that grows one position at a time. Per-token cost
 * drops from O(T²) to O(T) in attention and becomes constant in FFN.
 * All arithmetic still integer/boolean; SIMD added in a later step.
 */
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

typedef uint64_t u64;
typedef uint32_t u32;
typedef int32_t i32;
typedef int64_t i64;
typedef uint8_t u8;

/* First u32 of a model file; the bytes 'B','I','T','1' when stored little-endian. */
#define MAGIC_BIT1 0x31544942u

/* Capacities (in u64 words) of the stack scratch vectors used by step_token().
 * load_model() rejects any model that would not fit. */
enum {
    MAX_VEC_WORDS = 16, /* one d_model- or d_ff-wide bit vector */
    MAX_HEAD_WORDS = 8  /* one attention head's slice */
};

/* Model dimensions. The first six fields come from the file header; the
 * remaining ones are derived in load_model(). */
typedef struct {
    u32 vocab_size, d_model, n_layers, n_heads, d_ff, max_seq_len;
    i64 logit_scale_M;                          /* multiplies the output-head dot product */
    u32 head_dim, words_d, words_ff, words_head; /* d_model / n_heads, and u64 words per vector */
} Config;

/* Binary linear layer: one packed weight row and one threshold per output. */
typedef struct {
    u64 *weight_bits; /* out_features rows of words_in u64 each */
    i32 *threshold;   /* output bit i is set when the dot product >= threshold[i] */
    u32 in_features, out_features, words_in;
} BitLinear;

typedef struct {
    i32 *alibi_slopes; /* one slope per head, multiplied by the query-key distance */
    BitLinear q, k, v, o;
} Attention;

typedef struct {
    BitLinear gate, up, down;
} FFN;

typedef struct {
    Attention attn;
    FFN ffn;
    /* Per-layer KV cache. Bits stored compact per position for K and V. */
    u64 *k_cache; /* (max_seq_len, words_d) — but we only use first d_model bits per row */
    u64 *v_cache; /* same */
} Layer;

typedef struct {
    Config cfg;
    u64 *embed_bits;        /* vocab_size rows of words_d u64 */
    Layer *layers;          /* n_layers entries */
    u64 *out_codebook_bits; /* vocab_size rows of words_d u64 */
    i64 *int_out_bias;      /* vocab_size entries */
} Model;

/* ----------------- allocation helpers ----------------- */

/* Number of u64 words needed for `bits` bits; cannot wrap, unlike (bits + 63) / 64. */
static inline u32 words_for_bits(u32 bits)
{
    return bits / 64 + (bits % 64 != 0);
}

/* Returns a * b, or exits if the product does not fit in size_t. */
static size_t checked_mul(size_t a, size_t b, const char *what)
{
    if (b != 0 && a > SIZE_MAX / b) {
        fprintf(stderr, "size overflow %s\n", what);
        exit(1);
    }
    return a * b;
}

/* malloc that exits on failure. A zero-byte request is rounded up to one byte
 * so that a NULL result always means out-of-memory. */
static void *xmalloc(size_t n, const char *what)
{
    void *p = malloc(n ? n : 1);
    if (!p) {
        fprintf(stderr, "out of memory %s\n", what);
        exit(1);
    }
    return p;
}

/* calloc that exits on failure; same zero-size handling as xmalloc(). */
static void *xcalloc(size_t count, size_t size, const char *what)
{
    void *p = calloc(count ? count : 1, size ? size : 1);
    if (!p) {
        fprintf(stderr, "out of memory %s\n", what);
        exit(1);
    }
    return p;
}

/* ----------------- file I/O ----------------- */

/* Reads exactly n bytes into ptr, or prints "short read <what>" and exits. */
static void must_read(void *ptr, size_t n, FILE *f, const char *what)
{
    if (fread(ptr, 1, n, f) != n) {
        fprintf(stderr, "short read %s\n", what);
        exit(1);
    }
}

/* Reads one BitLinear from f: out_features packed weight rows followed by
 * out_features i32 thresholds. The buffers are owned by *bl. */
static void read_bitlinear(FILE *f, BitLinear *bl, u32 in_features, u32 out_features)
{
    bl->in_features = in_features;
    bl->out_features = out_features;
    bl->words_in = words_for_bits(in_features);

    size_t weight_bytes = checked_mul(checked_mul(out_features, bl->words_in, "bl.weights"),
                                      sizeof(u64), "bl.weights");
    size_t thr_bytes = checked_mul(out_features, sizeof(i32), "bl.thr");

    bl->weight_bits = xmalloc(weight_bytes, "bl.weights");
    bl->threshold = xmalloc(thr_bytes, "bl.thr");
    must_read(bl->weight_bits, weight_bytes, f, "bl.weights");
    must_read(bl->threshold, thr_bytes, f, "bl.thr");
}

/* Loads the model at `path` into *m and allocates zeroed KV caches.
 * Exits the process on any I/O, format, size, or allocation error.
 * Values are read raw, so the file must match the host's byte order. */
static void load_model(const char *path, Model *m)
{
    FILE *f = fopen(path, "rb");
    if (!f) {
        perror(path);
        exit(1);
    }

    u32 header[8];
    must_read(header, sizeof(header), f, "header");
    if (header[0] != MAGIC_BIT1) {
        fprintf(stderr, "bad magic\n");
        exit(1);
    }

    /* header[1] is not interpreted here. */
    Config *c = &m->cfg;
    c->vocab_size = header[2];
    c->d_model = header[3];
    c->n_layers = header[4];
    c->n_heads = header[5];
    c->d_ff = header[6];
    c->max_seq_len = header[7];
    must_read(&c->logit_scale_M, sizeof(i64), f, "M");

    /* n_heads is a divisor below; an empty vocabulary has no valid token id. */
    if (c->n_heads == 0 || c->vocab_size == 0) {
        fprintf(stderr, "bad config: n_heads=%u vocab_size=%u\n", c->n_heads, c->vocab_size);
        exit(1);
    }

    c->head_dim = c->d_model / c->n_heads;
    c->words_d = words_for_bits(c->d_model);
    c->words_ff = words_for_bits(c->d_ff);
    c->words_head = words_for_bits(c->head_dim);

    /* step_token() keeps its working vectors in fixed-size stack arrays. */
    if (c->words_d > MAX_VEC_WORDS || c->words_ff > MAX_VEC_WORDS ||
        c->words_head > MAX_HEAD_WORDS) {
        fprintf(stderr, "model too large for scratch buffers: d_model=%u d_ff=%u head_dim=%u\n",
                c->d_model, c->d_ff, c->head_dim);
        exit(1);
    }

    size_t table_bytes = checked_mul(checked_mul(c->vocab_size, c->words_d, "embed"),
                                     sizeof(u64), "embed");
    m->embed_bits = xmalloc(table_bytes, "embed");
    must_read(m->embed_bits, table_bytes, f, "embed");

    size_t kv_words = checked_mul(c->max_seq_len, c->words_d, "kv cache");
    size_t alibi_bytes = checked_mul(c->n_heads, sizeof(i32), "alibi");

    m->layers = xcalloc(c->n_layers, sizeof *m->layers, "layers");
    for (u32 l = 0; l < c->n_layers; l++) {
        Layer *ly = &m->layers[l];

        ly->attn.alibi_slopes = xmalloc(alibi_bytes, "alibi");
        must_read(ly->attn.alibi_slopes, alibi_bytes, f, "alibi");

        read_bitlinear(f, &ly->attn.q, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.k, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.v, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.o, c->d_model, c->d_model);
        read_bitlinear(f, &ly->ffn.gate, c->d_model, c->d_ff);
        read_bitlinear(f, &ly->ffn.up, c->d_model, c->d_ff);
        read_bitlinear(f, &ly->ffn.down, c->d_ff, c->d_model);

        /* Allocate per-layer KV caches. */
        ly->k_cache = xcalloc(kv_words, sizeof(u64), "k_cache");
        ly->v_cache = xcalloc(kv_words, sizeof(u64), "v_cache");
    }

    m->out_codebook_bits = xmalloc(table_bytes, "out_codebook");
    must_read(m->out_codebook_bits, table_bytes, f, "out_codebook");

    size_t bias_bytes = checked_mul(c->vocab_size, sizeof(i64), "out_bias");
    m->int_out_bias = xmalloc(bias_bytes, "out_bias");
    must_read(m->int_out_bias, bias_bytes, f, "out_bias");

    fclose(f);
}

/* Releases the two buffers owned by a BitLinear. */
static void free_bitlinear(BitLinear *bl)
{
    free(bl->weight_bits);
    free(bl->threshold);
}

/* Releases everything load_model() allocated and clears the pointers. */
static void free_model(Model *m)
{
    if (m->layers) {
        for (u32 l = 0; l < m->cfg.n_layers; l++) {
            Layer *ly = &m->layers[l];
            free(ly->attn.alibi_slopes);
            free_bitlinear(&ly->attn.q);
            free_bitlinear(&ly->attn.k);
            free_bitlinear(&ly->attn.v);
            free_bitlinear(&ly->attn.o);
            free_bitlinear(&ly->ffn.gate);
            free_bitlinear(&ly->ffn.up);
            free_bitlinear(&ly->ffn.down);
            free(ly->k_cache);
            free(ly->v_cache);
        }
    }
    free(m->layers);
    free(m->embed_bits);
    free(m->out_codebook_bits);
    free(m->int_out_bias);
    m->layers = NULL;
    m->embed_bits = NULL;
    m->out_codebook_bits = NULL;
    m->int_out_bias = NULL;
}

/* ----------------- primitives ----------------- */

/* Dot product of two packed vectors where a set bit means +1 and a clear bit
 * means -1: (#agreeing positions) - (#disagreeing positions), taken over the
 * first in_features bits only. Padding bits in the last word are masked out,
 * so the result does not depend on what they contain. */
static inline i32 bipolar_dot(const u64 *a, const u64 *b, u32 words, u32 in_features)
{
    i64 agree = 0;
    for (u32 w = 0; w < words; w++) {
        u64 first_bit = (u64)w * 64;
        if (first_bit >= in_features) {
            break; /* this word and all later ones hold only padding */
        }
        u64 same = ~(a[w] ^ b[w]);
        u64 valid = in_features - first_bit;
        if (valid < 64) {
            same &= ((u64)1 << valid) - 1;
        }
        agree += __builtin_popcountll(same);
    }
    return (i32)(2 * agree - in_features);
}

/* Applies a BitLinear: output bit i is set iff the bipolar dot product of
 * weight row i with x_bits reaches threshold[i]. out_bits must hold
 * ceil(out_features / 64) words; its padding bits are left zero. */
static void bitlinear_forward(const BitLinear *bl, const u64 *x_bits, u64 *out_bits)
{
    u32 words_out = words_for_bits(bl->out_features);
    memset(out_bits, 0, words_out * sizeof(u64));
    for (u32 i = 0; i < bl->out_features; i++) {
        const u64 *w_row = bl->weight_bits + (size_t)i * bl->words_in;
        i32 y = bipolar_dot(w_row, x_bits, bl->words_in, bl->in_features);
        if (y >= bl->threshold[i]) {
            out_bits[i / 64] |= ((u64)1) << (i % 64);
        }
    }
}

/* Copies bits [h * head_dim, (h + 1) * head_dim) of x_bits into head_bits,
 * re-aligned to start at bit 0. Unused high bits of the last word are zero. */
static inline void extract_head(const u64 *x_bits, u32 head_dim, u32 h, u64 *head_bits)
{
    u32 start_bit = h * head_dim;
    u32 end_bit = start_bit + head_dim;
    u32 words = words_for_bits(head_dim);
    for (u32 w = 0; w < words; w++) {
        u32 bit_start = start_bit + w * 64;
        u32 lo_word = bit_start / 64;
        u32 shift = bit_start % 64;
        u64 v = x_bits[lo_word] >> shift;
        /* Take the spill-over from the next word only when that word still
         * holds bits of this head. Requiring shift != 0 also avoids an
         * undefined shift by 64. */
        if (shift && (lo_word + 1) * 64 < end_bit) {
            v |= x_bits[lo_word + 1] << (64 - shift);
        }
        u32 remaining = (w + 1) * 64 <= head_dim ? 64 : (head_dim - w * 64);
        if (remaining < 64) {
            v &= (((u64)1 << remaining) - 1);
        }
        head_bits[w] = v;
    }
}

/* Bitwise majority vote of three packed vectors. */
static inline void majority3(const u64 *a, const u64 *b, const u64 *c, u64 *out, u32 words)
{
    for (u32 w = 0; w < words; w++) {
        out[w] = (a[w] & b[w]) | (a[w] & c[w]) | (b[w] & c[w]);
    }
}

/* ----------------- KV-cached forward step -----------------
 * Given a token at position t (0-indexed), fills K/V cache at that position and
 * advances x through all layers. Returns argmax next-token logit over the vocab.
 * Positions 0..t-1 of every layer's cache must already have been filled by
 * earlier calls. Exits if token_id or t is out of range for the model.
 */
static u32 step_token(Model *m, u32 token_id, u32 t)
{
    const Config *c = &m->cfg;
    u32 wd = c->words_d, wf = c->words_ff, wh = c->words_head;

    if (token_id >= c->vocab_size || t >= c->max_seq_len) {
        fprintf(stderr, "step_token: token %u or position %u out of range\n", token_id, t);
        exit(1);
    }

    /* Scratch vectors; load_model() guarantees wd, wf <= MAX_VEC_WORDS and
     * wh <= MAX_HEAD_WORDS. */
    u64 x[MAX_VEC_WORDS];
    u64 q[MAX_VEC_WORDS];
    u64 k[MAX_VEC_WORDS];
    u64 v[MAX_VEC_WORDS];
    u64 a_bits[MAX_VEC_WORDS];
    u64 o_bits[MAX_VEC_WORDS];
    u64 g_bits[MAX_VEC_WORDS];
    u64 u_bits[MAX_VEC_WORDS];
    u64 h_bits[MAX_VEC_WORDS];
    u64 f_bits[MAX_VEC_WORDS];
    u64 new_x[MAX_VEC_WORDS];
    u64 q_head[MAX_HEAD_WORDS];
    u64 k_head[MAX_HEAD_WORDS];

    /* Embed */
    memcpy(x, m->embed_bits + (size_t)token_id * wd, wd * sizeof(u64));

    for (u32 li = 0; li < c->n_layers; li++) {
        Layer *ly = &m->layers[li];

        /* Project Q, K, V */
        bitlinear_forward(&ly->attn.q, x, q);
        bitlinear_forward(&ly->attn.k, x, k);
        bitlinear_forward(&ly->attn.v, x, v);

        /* Cache K, V at position t. */
        memcpy(ly->k_cache + (size_t)t * wd, k, wd * sizeof(u64));
        memcpy(ly->v_cache + (size_t)t * wd, v, wd * sizeof(u64));

        /* Attention: for each head, argmax over keys 0..t, then gather V. */
        memset(a_bits, 0, wd * sizeof(u64));
        for (u32 h = 0; h < c->n_heads; h++) {
            extract_head(q, c->head_dim, h, q_head);

            /* Scores are kept in 64 bits so slope * distance cannot overflow
             * a 32-bit int. Ties go to the earliest position. */
            i64 slope = ly->attn.alibi_slopes[h];
            i64 best_score = INT64_MIN;
            u32 best_j = 0;
            for (u32 j = 0; j <= t; j++) {
                const u64 *k_j = ly->k_cache + (size_t)j * wd;
                extract_head(k_j, c->head_dim, h, k_head);
                i64 s = bipolar_dot(q_head, k_head, wh, c->head_dim);
                s -= slope * (i64)(t - j); /* j <= t, so the distance is non-negative */
                if (s > best_score) {
                    best_score = s;
                    best_j = j;
                }
            }

            /* Copy this head's slice of the winning V row into a_bits. */
            const u64 *v_bits = ly->v_cache + (size_t)best_j * wd;
            for (u32 bit = 0; bit < c->head_dim; bit++) {
                u32 src_bit = h * c->head_dim + bit;
                u64 vv = (v_bits[src_bit / 64] >> (src_bit % 64)) & 1ULL;
                a_bits[src_bit / 64] |= vv << (src_bit % 64);
            }
        }

        /* Output projection */
        bitlinear_forward(&ly->attn.o, a_bits, o_bits);

        /* FFN: XNOR of gate and up, then down projection. The XNOR sets the
         * padding bits of h_bits; bipolar_dot() ignores them. */
        bitlinear_forward(&ly->ffn.gate, x, g_bits);
        bitlinear_forward(&ly->ffn.up, x, u_bits);
        for (u32 w = 0; w < wf; w++) {
            h_bits[w] = ~(g_bits[w] ^ u_bits[w]);
        }
        bitlinear_forward(&ly->ffn.down, h_bits, f_bits);

        /* Residual majority */
        majority3(x, o_bits, f_bits, new_x, wd);
        memcpy(x, new_x, wd * sizeof(u64));
    }

    /* Output head: the first token with the highest logit wins. */
    i64 best_logit = INT64_MIN;
    u32 best_v = 0;
    for (u32 vid = 0; vid < c->vocab_size; vid++) {
        const u64 *vec = m->out_codebook_bits + (size_t)vid * wd;
        i32 dot = bipolar_dot(vec, x, wd, c->d_model);
        i64 logit = (i64)dot * c->logit_scale_M + m->int_out_bias[vid];
        if (logit > best_logit) {
            best_logit = logit;
            best_v = vid;
        }
    }
    return best_v;
}

/* Current time in milliseconds, used only for measuring elapsed intervals.
 * Uses the monotonic clock when the platform headers expose it, otherwise
 * falls back to the C11 wall clock. */
static double now_ms(void)
{
    struct timespec ts;
#ifdef CLOCK_MONOTONIC
    clock_gettime(CLOCK_MONOTONIC, &ts);
#else
    timespec_get(&ts, TIME_UTC);
#endif
    return ts.tv_sec * 1000.0 + ts.tv_nsec * 1e-6;
}

/* Tokens per second, or 0 when no measurable time elapsed. */
static double tokens_per_second(u32 count, double ms)
{
    return ms > 0.0 ? count * 1000.0 / ms : 0.0;
}

/* Parses a non-negative decimal integer that fits in u32.
 * Returns 0 on success, -1 on malformed or out-of-range input. */
static int parse_u32(const char *s, u32 *out)
{
    char *end = NULL;
    errno = 0;
    long long v = strtoll(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE || v < 0 || v > (long long)UINT32_MAX) {
        return -1;
    }
    *out = (u32)v;
    return 0;
}

/* ----------------- main ----------------- */

/* Loads the model, feeds the prompt one byte per token, then greedily
 * generates n_new tokens, writing prompt and output bytes to stdout and
 * timing statistics to stderr. Returns 2 on bad arguments. */
int main(int argc, char **argv)
{
    if (argc < 4) {
        fprintf(stderr, "usage: %s <model.bin> \"<prompt>\" <n_new>\n", argv[0]);
        return 2;
    }
    const char *bin = argv[1];
    const char *prompt = argv[2];
    u32 n_new = 0;
    if (parse_u32(argv[3], &n_new) != 0) {
        fprintf(stderr, "n_new must be a non-negative integer, got \"%s\"\n", argv[3]);
        return 2;
    }

    Model m = {0};
    load_model(bin, &m);
    fprintf(stderr, "loaded: vocab=%u d=%u L=%u H=%u ff=%u Tmax=%u\n",
            m.cfg.vocab_size, m.cfg.d_model, m.cfg.n_layers,
            m.cfg.n_heads, m.cfg.d_ff, m.cfg.max_seq_len);

    /* Compare without adding, so the sum cannot wrap around. */
    size_t prompt_bytes = strlen(prompt);
    if (prompt_bytes == 0 || prompt_bytes > m.cfg.max_seq_len ||
        n_new > m.cfg.max_seq_len - prompt_bytes) {
        fprintf(stderr, "prompt_len=%zu + n_new=%u must be <= %u\n",
                prompt_bytes, n_new, m.cfg.max_seq_len);
        free_model(&m);
        return 2;
    }
    u32 prompt_len = (u32)prompt_bytes;

    /* Each prompt byte is used directly as a token id. */
    for (u32 t = 0; t < prompt_len; t++) {
        u32 id = (u32)(u8)prompt[t];
        if (id >= m.cfg.vocab_size) {
            fprintf(stderr, "prompt byte %u at offset %u is outside vocab of %u\n",
                    id, t, m.cfg.vocab_size);
            free_model(&m);
            return 2;
        }
    }

    /* Prefill: step through the prompt, ignoring all but the last prediction. */
    double t_prefill_start = now_ms();
    u32 next_id = 0;
    for (u32 t = 0; t < prompt_len; t++) {
        next_id = step_token(&m, (u32)(u8)prompt[t], t);
    }
    double t_prefill_ms = now_ms() - t_prefill_start;

    /* Emit prompt */
    fwrite(prompt, 1, prompt_len, stdout);

    /* Generate */
    double t_gen_start = now_ms();
    for (u32 step = 0; step < n_new; step++) {
        putchar((int)next_id);
        fflush(stdout);
        u32 pos = prompt_len + step;
        if (pos >= m.cfg.max_seq_len) {
            break;
        }
        next_id = step_token(&m, next_id, pos);
    }
    double t_gen_ms = now_ms() - t_gen_start;
    putchar('\n');

    fprintf(stderr,
            "prefill: %u tok in %.1f ms (%.1f tok/s)\n"
            "generate: %u tok in %.1f ms (%.1f tok/s)\n",
            prompt_len, t_prefill_ms, tokens_per_second(prompt_len, t_prefill_ms),
            n_new, t_gen_ms, tokens_per_second(n_new, t_gen_ms));

    free_model(&m);
    return 0;
}