/* Author: hidude562 - "1bitllm code (checkpoints to follow)"
 * Commit 4754707 (verified)
 */
/* Pure-integer inference for v18 exported checkpoint.
*
* Build: gcc -O3 -march=native -o infer infer.c
* Run: ./infer <weights.bin> "<prompt>" <num_new_tokens>
*
* No floating-point arithmetic on the inference hot path.
* Ops used: XNOR + popcount (__builtin_popcountll), integer add/sub/compare,
* indexed memory read (gather).
*/
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Short fixed-width integer aliases used throughout this file
 * (unsigned widths first, then signed). */
typedef uint8_t u8;
typedef uint32_t u32;
typedef uint64_t u64;
typedef int32_t i32;
typedef int64_t i64;
/* Checkpoint magic: the four bytes 'B','I','T','1' interpreted as a
 * little-endian u32. */
#define MAGIC_BIT1 0x31544942u
/* --- model configuration (read from header) --- */
/* Hyperparameters of a loaded checkpoint. load_model fills vocab_size through
 * max_seq_len from header words 2..7, reads logit_scale_M as the i64 that
 * follows the header, and then derives head_dim and the words_* counts. */
typedef struct {
u32 vocab_size; /* token ids range over 0..vocab_size-1 */
u32 d_model; /* hidden width, in bits */
u32 n_layers;
u32 n_heads;
u32 d_ff; /* FFN inner width, in bits */
u32 max_seq_len; /* context length; per-position buffers hold this many rows */
i64 logit_scale_M; /* integer multiplier applied to the output-head dot product */
u32 head_dim; /* d_model / n_heads */
u32 words_d; /* ceil(d_model / 64) */
u32 words_ff; /* ceil(d_ff / 64) */
u32 words_head; /* ceil(head_dim / 64) */
} Config;
/* --- packed binary weights + integer thresholds per BitLinear --- */
/* One binarized linear layer. Row i of weight_bits (words_in u64 words, bit 1
 * = +1, bit 0 = -1) is combined with the packed input via bipolar_dot, and
 * output bit i is set when that dot product is >= threshold[i]
 * (see bitlinear_forward). */
typedef struct {
u64 *weight_bits; /* out_rows * words_in */
i32 *threshold; /* out_rows */
u32 in_features;
u32 out_features;
u32 words_in; /* ceil(in_features / 64), set by read_bitlinear */
} BitLinear;
/* Attention weights of one layer. */
typedef struct {
i32 *alibi_slopes; /* n_heads; a key at distance |t - j| from query t has its score reduced by slope * |t - j| */
BitLinear q, k, v, o; /* each maps d_model -> d_model */
} Attention;
/* Feed-forward weights of one layer: the forward pass computes
 * down(gate(x) XNOR up(x)). gate and up map d_model -> d_ff; down maps
 * d_ff -> d_model. */
typedef struct {
BitLinear gate, up, down;
} FFN;
/* One transformer layer: attention followed by the FFN, combined with the
 * residual by a bitwise majority vote in argmax_next_token. */
typedef struct {
Attention attn;
FFN ffn;
} Layer;
/* A fully loaded checkpoint. All arrays are heap-allocated by load_model. */
typedef struct {
Config cfg;
u64 *embed_bits; /* vocab * words_d; row id is the packed embedding of token id */
Layer *layers; /* n_layers */
u64 *out_codebook_bits; /* vocab * words_d; compared against the final hidden state with bipolar_dot */
i64 *int_out_bias; /* vocab; added to the scaled output dot product to form each logit */
} Model;
/* --- Forward activation buffers (bit-packed) --- */
/* Scratch memory for argmax_next_token, allocated once by alloc_buffers.
 * Per-position arrays hold max_seq_len rows; row t belongs to position t. */
typedef struct {
u64 *x; /* (seq_len, words_d) current hidden state per position */
u64 *q_all; /* (seq_len, words_d) Q projections per position */
u64 *k_all; /* (seq_len, words_d) K projections per position */
u64 *v_all; /* (seq_len, words_d) V projections per position */
u64 *a_bits; /* (words_d) attention output per position (tmp) */
u64 *f_bits; /* (words_d) FFN output per position (tmp) */
u64 *g_bits; /* (words_ff) FFN gate(x) tmp */
u64 *u_bits; /* (words_ff) FFN up(x) tmp */
u64 *h_bits; /* (words_ff) g XNOR u tmp */
i32 *scores; /* (n_heads, seq_len) attention scores per (head, key) for current query */
/* NOTE(review): scores is allocated by alloc_buffers but never read or
 * written by argmax_next_token, which tracks the best score inline. */
} Buffers;
/* ----------------- file reading ----------------- */
/* Read exactly n bytes from f into ptr. On a short read (EOF or I/O error)
 * print which field was being read and terminate with exit status 1. */
static void must_read(void *ptr, size_t n, FILE *f, const char *what) {
    const size_t got = fread(ptr, 1, n, f);
    if (got == n)
        return;
    fprintf(stderr, "short read on %s\n", what);
    exit(1);
}
/* Read one BitLinear layer from f into *bl.
 *
 * On-disk layout: out_features rows of packed weight bits (words_in u64 words
 * per row, words_in = ceil(in_features / 64)), followed by out_features i32
 * thresholds. Sets bl's dimensions and allocates both arrays.
 * Exits the process on allocation failure or short read. */
static void read_bitlinear(FILE *f, BitLinear *bl, u32 in_features, u32 out_features) {
    bl->in_features = in_features;
    bl->out_features = out_features;
    bl->words_in = (in_features + 63) / 64;
    size_t wb = (size_t)out_features * bl->words_in * sizeof(u64);
    size_t tb = (size_t)out_features * sizeof(i32);
    bl->weight_bits = (u64 *)malloc(wb);
    bl->threshold = (i32 *)malloc(tb);
    /* malloc(0) may legitimately return NULL, so NULL is only an error for a
     * non-empty request; otherwise must_read would write through NULL. */
    if ((wb && !bl->weight_bits) || (tb && !bl->threshold)) {
        fprintf(stderr, "out of memory allocating BitLinear (%u x %u)\n", out_features, in_features);
        exit(1);
    }
    must_read(bl->weight_bits, wb, f, "BitLinear.weight_bits");
    must_read(bl->threshold, tb, f, "BitLinear.threshold");
}
/* Load a "BIT1" version-1 checkpoint from path into *m.
 *
 * File layout, in read order:
 *   u32 header[8]   magic, version, vocab_size, d_model, n_layers, n_heads,
 *                   d_ff, max_seq_len
 *   i64 logit_scale_M
 *   u64 embed_bits[vocab_size * words_d]
 *   per layer: i32 alibi_slopes[n_heads], then BitLinear q, k, v, o
 *              (d_model -> d_model), gate, up (d_model -> d_ff),
 *              down (d_ff -> d_model)
 *   u64 out_codebook_bits[vocab_size * words_d]
 *   i64 int_out_bias[vocab_size]
 * Values are read in host byte order. Extra bytes after the last field only
 * produce a warning.
 *
 * Exits the process on open failure, short read, bad magic/version, an
 * invalid configuration, or allocation failure. */
static void load_model(const char *path, Model *m) {
    FILE *f = fopen(path, "rb");
    if (!f) { perror(path); exit(1); }
    u32 header[8];
    must_read(header, sizeof(header), f, "header");
    if (header[0] != MAGIC_BIT1) {
        fprintf(stderr, "bad magic 0x%08x (want 0x%08x)\n", header[0], MAGIC_BIT1);
        exit(1);
    }
    if (header[1] != 1) { fprintf(stderr, "bad version %u\n", header[1]); exit(1); }
    Config *c = &m->cfg;
    c->vocab_size = header[2];
    c->d_model = header[3];
    c->n_layers = header[4];
    c->n_heads = header[5];
    c->d_ff = header[6];
    c->max_seq_len = header[7];
    must_read(&c->logit_scale_M, sizeof(i64), f, "logit_scale_M");
    /* Validate before deriving sizes: n_heads == 0 would divide by zero, and
     * attention gathers head h into bits [h*head_dim, (h+1)*head_dim), so the
     * heads must tile d_model exactly. */
    if (c->vocab_size == 0 || c->d_model == 0 || c->n_heads == 0 || c->max_seq_len == 0) {
        fprintf(stderr, "invalid config: vocab=%u d_model=%u n_heads=%u T_max=%u\n",
                c->vocab_size, c->d_model, c->n_heads, c->max_seq_len);
        exit(1);
    }
    if (c->d_model % c->n_heads != 0) {
        fprintf(stderr, "invalid config: d_model=%u is not divisible by n_heads=%u\n",
                c->d_model, c->n_heads);
        exit(1);
    }
    c->head_dim = c->d_model / c->n_heads;
    c->words_d = (c->d_model + 63) / 64;
    c->words_ff = (c->d_ff + 63) / 64;
    c->words_head = (c->head_dim + 63) / 64;
    /* Embedding */
    size_t eb = (size_t)c->vocab_size * c->words_d * sizeof(u64);
    m->embed_bits = (u64 *)malloc(eb);
    if (!m->embed_bits) goto oom;
    must_read(m->embed_bits, eb, f, "embedding");
    /* Layers */
    m->layers = (Layer *)calloc(c->n_layers, sizeof(Layer));
    if (c->n_layers && !m->layers) goto oom;
    for (u32 l = 0; l < c->n_layers; l++) {
        Layer *ly = &m->layers[l];
        ly->attn.alibi_slopes = (i32 *)malloc(c->n_heads * sizeof(i32));
        if (!ly->attn.alibi_slopes) goto oom;
        must_read(ly->attn.alibi_slopes, c->n_heads * sizeof(i32), f, "alibi_slopes");
        read_bitlinear(f, &ly->attn.q, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.k, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.v, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.o, c->d_model, c->d_model);
        read_bitlinear(f, &ly->ffn.gate, c->d_model, c->d_ff);
        read_bitlinear(f, &ly->ffn.up, c->d_model, c->d_ff);
        read_bitlinear(f, &ly->ffn.down, c->d_ff, c->d_model);
    }
    /* Output head */
    m->out_codebook_bits = (u64 *)malloc(eb);
    if (!m->out_codebook_bits) goto oom;
    must_read(m->out_codebook_bits, eb, f, "out_codebook");
    m->int_out_bias = (i64 *)malloc(c->vocab_size * sizeof(i64));
    if (!m->int_out_bias) goto oom;
    must_read(m->int_out_bias, c->vocab_size * sizeof(i64), f, "int_out_bias");
    /* Make sure we reached EOF */
    u8 tail;
    size_t got = fread(&tail, 1, 1, f);
    if (got != 0) {
        fprintf(stderr, "warning: extra bytes after expected EOF\n");
    }
    fclose(f);
    return;
oom:
    fprintf(stderr, "out of memory while loading %s\n", path);
    fclose(f);
    exit(1);
}
/* ----------------- buffers ----------------- */
static Buffers alloc_buffers(const Config *c) {
Buffers b = {0};
size_t wd = c->words_d;
size_t wf = c->words_ff;
b.x = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64));
b.q_all = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64));
b.k_all = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64));
b.v_all = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64));
b.a_bits = (u64 *)calloc(wd, sizeof(u64));
b.f_bits = (u64 *)calloc(wd, sizeof(u64));
b.g_bits = (u64 *)calloc(wf, sizeof(u64));
b.u_bits = (u64 *)calloc(wf, sizeof(u64));
b.h_bits = (u64 *)calloc(wf, sizeof(u64));
b.scores = (i32 *)calloc((size_t)c->n_heads * c->max_seq_len, sizeof(i32));
return b;
}
/* ----------------- primitive ops ----------------- */
/* XNOR-popcount dot product of two bit-packed ±1 vectors.
 *
 * Bit value 1 encodes +1 and 0 encodes -1, so over in_features elements
 *     dot = (#agreeing positions) - (#disagreeing positions)
 *         = 2 * popcount(a XNOR b) - in_features.
 *
 * a and b each span `words` u64 words (normally ceil(in_features / 64)).
 * Only the first in_features bits are counted: padding bits in a partial last
 * word, and any whole words beyond in_features, are ignored. The result is
 * therefore correct however either operand's padding bits are set.
 */
static inline int32_t bipolar_dot(const uint64_t *a, const uint64_t *b, uint32_t words, uint32_t in_features) {
    const uint32_t tail = in_features % 64; /* valid bits in a partial last word */
    uint32_t whole = in_features / 64;      /* words whose 64 bits are all valid */
    if (whole > words) whole = words;
    int64_t agree = 0;
    for (uint32_t w = 0; w < whole; w++) {
        agree += __builtin_popcountll(~(a[w] ^ b[w])); /* XNOR: 1 where bits agree */
    }
    if (tail != 0 && whole < words) {
        const uint64_t valid = ((uint64_t)1 << tail) - 1;
        agree += __builtin_popcountll(~(a[whole] ^ b[whole]) & valid);
    }
    return (int32_t)(2 * agree - (int64_t)in_features);
}
/* BitLinear forward pass.
 * For each output row i: y_i = bipolar_dot(W_i, x_bits) over in_features
 * inputs, and output bit i is 1 iff y_i >= threshold[i]. Results are packed
 * 64 per word into ceil(out_features / 64) words of out_bits, and bits past
 * out_features in the last word are written as 0. out_bits must not overlap
 * x_bits. */
static void bitlinear_forward(const BitLinear *bl, const u64 *x_bits, u64 *out_bits) {
    const u32 n_out = bl->out_features;
    const u32 n_words = (n_out + 63) / 64;
    const u64 *row = bl->weight_bits;
    for (u32 ow = 0; ow < n_words; ow++) {
        const u32 base = ow * 64;
        const u32 lim = (n_out - base < 64) ? n_out - base : 64;
        u64 packed = 0;
        for (u32 k = 0; k < lim; k++, row += bl->words_in) {
            const i32 y = bipolar_dot(row, x_bits, bl->words_in, bl->in_features);
            packed |= (u64)(y >= bl->threshold[base + k]) << k;
        }
        out_bits[ow] = packed;
    }
}
/* Copy head h's slice of a packed d_model-bit vector into head_bits.
 * Head h occupies source bits [h*head_dim, (h+1)*head_dim). Writes
 * ceil(head_dim / 64) words, aligned so bit 0 of head_bits[0] is the slice's
 * first bit; bits past head_dim in the last written word are cleared. */
static inline void extract_head(const u64 *x_bits, u32 head_dim, u32 h, u64 *head_bits) {
    const u32 first = h * head_dim;   /* first source bit of the slice */
    const u32 last = first + head_dim; /* one past the last source bit */
    const u32 n_words = (head_dim + 63) / 64;
    for (u32 w = 0; w < n_words; w++) {
        const u32 src = first + 64 * w;
        const u32 word = src / 64;
        const u32 off = src % 64;
        u64 chunk = x_bits[word] >> off;
        /* Pull the high part from the next source word only when the slice
         * actually extends into it. */
        if (off != 0 && 64 * (word + 1) < last)
            chunk |= x_bits[word + 1] << (64 - off);
        const u32 kept = head_dim - 64 * w; /* slice bits still to place */
        if (kept < 64)
            chunk &= ((u64)1 << kept) - 1;
        head_bits[w] = chunk;
    }
}
/* Bitwise majority of three packed ±1 vectors of `words` words each.
 * An output bit is 1 iff at least two of the three input bits are 1, i.e. the
 * sign of a + b + c in ±1 terms. Uses the identity
 *   (a & b) | (a & c) | (b & c) == (a & (b | c)) | (b & c). */
static inline void majority3(const u64 *a, const u64 *b, const u64 *c, u64 *out, u32 words) {
    for (u32 k = 0; k < words; k++) {
        const u64 x = a[k], y = b[k], z = c[k];
        out[k] = (x & (y | z)) | (y & z);
    }
}
/* ----------------- forward ----------------- */
/* Greedy next-token prediction: run the full model over ids[0..T-1] and
 * return the vocab id with the highest integer logit at position T-1.
 *
 * No KV cache is kept across calls. Each call re-embeds and recomputes all T
 * positions through every layer, so it suits small max_seq_len.
 *
 * Per layer, with all activations bit-packed (bit 1 = +1, bit 0 = -1):
 *   1. Q, K, V = BitLinear(x_t) for every position t.
 *   2. Hard attention: for each head h, query t picks the single key j <= t
 *      that maximises dot(q_t[h], k_j[h]) - alibi_slopes[h] * (t - j)
 *      (ties -> lowest j). It copies head h's bits of V_j.
 *   3. O projection of the gathered bits. The FFN runs on the layer input:
 *      down(gate(x_t) XNOR up(x_t)).
 *   4. New x_t = bitwise majority of (x_t, O output, FFN output).
 * The output head scores vocab row v as
 *   bipolar_dot(out_codebook[v], x_{T-1}) * logit_scale_M + int_out_bias[v]
 * (ties -> lowest id).
 *
 * Requires 1 <= T <= max_seq_len (checked) and ids[t] < vocab_size for every t
 * (not checked; the caller folds inputs into range). Overwrites b->x,
 * b->q_all, b->k_all, b->v_all and the per-position temporaries. Exits the
 * process on an invalid T or if scratch allocation fails.
 */
static u32 argmax_next_token(const Model *m, Buffers *b, const u32 *ids, u32 T) {
    const Config *c = &m->cfg;
    const u32 wd = c->words_d;
    const u32 wf = c->words_ff;
    const u32 wh = c->words_head;
    if (T == 0 || T > c->max_seq_len) {
        fprintf(stderr, "argmax_next_token: T=%u outside [1, %u]\n", T, c->max_seq_len);
        exit(1);
    }
    /* Per-call scratch sized from the config, so any head_dim and d_model fit.
     * It holds q_head and k_head (wh words each), then o_bits and new_x
     * (wd words each). */
    const size_t n_scratch = 2 * (size_t)wh + 2 * (size_t)wd;
    u64 *scratch = (u64 *)malloc((n_scratch ? n_scratch : 1) * sizeof(u64));
    if (!scratch) {
        fprintf(stderr, "out of memory in argmax_next_token\n");
        exit(1);
    }
    u64 *q_head = scratch;
    u64 *k_head = q_head + wh;
    u64 *o_bits = k_head + wh;
    u64 *new_x = o_bits + wd;
    /* gate/up leave their bits beyond d_ff at 0, but XNOR maps (0,0) to 1.
     * This mask re-zeroes h_bits' padding so `down` sees a zero-padded input
     * like every other packed vector. */
    const u32 ff_tail = c->d_ff % 64;
    const u64 ff_tail_mask = ff_tail ? (((u64)1 << ff_tail) - 1) : ~(u64)0;

    /* Embed */
    for (u32 t = 0; t < T; t++) {
        memcpy(b->x + (size_t)t * wd,
               m->embed_bits + (size_t)ids[t] * wd,
               wd * sizeof(u64));
    }
    for (u32 li = 0; li < c->n_layers; li++) {
        const Layer *ly = &m->layers[li];
        /* Q, K, V for every position from this layer's input x. Computing them
         * all first is what allows x_t to be overwritten in place below. */
        for (u32 t = 0; t < T; t++) {
            const u64 *xt = b->x + (size_t)t * wd;
            bitlinear_forward(&ly->attn.q, xt, b->q_all + (size_t)t * wd);
            bitlinear_forward(&ly->attn.k, xt, b->k_all + (size_t)t * wd);
            bitlinear_forward(&ly->attn.v, xt, b->v_all + (size_t)t * wd);
        }
        /* Attention + O + FFN + residual, per position. */
        for (u32 t = 0; t < T; t++) {
            const u64 *q_t = b->q_all + (size_t)t * wd;
            u64 *x_t = b->x + (size_t)t * wd;
            memset(b->a_bits, 0, wd * sizeof(u64));
            for (u32 h = 0; h < c->n_heads; h++) {
                extract_head(q_t, c->head_dim, h, q_head);
                i32 best_score = INT32_MIN;
                u32 best_j = 0;
                for (u32 j = 0; j <= t; j++) {
                    extract_head(b->k_all + (size_t)j * wd, c->head_dim, h, k_head);
                    i32 s = bipolar_dot(q_head, k_head, wh, c->head_dim);
                    s -= ly->attn.alibi_slopes[h] * (i32)(t - j); /* j <= t, so |t - j| = t - j */
                    if (s > best_score) {
                        best_score = s;
                        best_j = j;
                    }
                }
                /* Copy bits [h*head_dim, (h+1)*head_dim) of V_best_j into a_bits. */
                const u64 *v_bits = b->v_all + (size_t)best_j * wd;
                const u32 lo = h * c->head_dim;
                const u32 hi = lo + c->head_dim;
                for (u32 bit = lo; bit < hi; bit++) {
                    b->a_bits[bit / 64] |= ((v_bits[bit / 64] >> (bit % 64)) & 1ULL) << (bit % 64);
                }
            }
            bitlinear_forward(&ly->attn.o, b->a_bits, o_bits);
            /* FFN on the layer input x_t: h = gate(x) XNOR up(x), then down. */
            bitlinear_forward(&ly->ffn.gate, x_t, b->g_bits);
            bitlinear_forward(&ly->ffn.up, x_t, b->u_bits);
            for (u32 w = 0; w < wf; w++) {
                b->h_bits[w] = ~(b->g_bits[w] ^ b->u_bits[w]);
            }
            if (wf) b->h_bits[wf - 1] &= ff_tail_mask;
            bitlinear_forward(&ly->ffn.down, b->h_bits, b->f_bits);
            /* Residual sign = majority-of-3 over (x, O output, FFN output). */
            majority3(x_t, o_bits, b->f_bits, new_x, wd);
            memcpy(x_t, new_x, wd * sizeof(u64));
        }
    }
    /* Output head at the last position: popcount dot product with each vocab row. */
    const u64 *x_last = b->x + (size_t)(T - 1) * wd;
    i64 best_logit = INT64_MIN;
    u32 best_v = 0;
    for (u32 v = 0; v < c->vocab_size; v++) {
        const u64 *vec = m->out_codebook_bits + (size_t)v * wd;
        i32 dot = bipolar_dot(vec, x_last, wd, c->d_model);
        i64 logit = (i64)dot * c->logit_scale_M + m->int_out_bias[v];
        if (logit > best_logit) {
            best_logit = logit;
            best_v = v;
        }
    }
    free(scratch);
    return best_v;
}
/* ----------------- main ----------------- */
int main(int argc, char **argv) {
if (argc < 4) {
fprintf(stderr, "usage: %s <weights.bin> \"<prompt>\" <num_new_tokens>\n", argv[0]);
return 2;
}
const char *bin_path = argv[1];
const char *prompt = argv[2];
u32 n_new = (u32)atoi(argv[3]);
Model m = {0};
load_model(bin_path, &m);
fprintf(stderr,
"loaded v18 bin: vocab=%u d_model=%u n_layers=%u n_heads=%u d_ff=%u T_max=%u M=%" PRId64 "\n",
m.cfg.vocab_size, m.cfg.d_model, m.cfg.n_layers, m.cfg.n_heads, m.cfg.d_ff,
m.cfg.max_seq_len, m.cfg.logit_scale_M);
Buffers b = alloc_buffers(&m.cfg);
/* Encode prompt as char IDs (ASCII). */
u32 prompt_len = (u32)strlen(prompt);
if (prompt_len == 0) { fprintf(stderr, "empty prompt\n"); return 2; }
if (prompt_len > m.cfg.max_seq_len) prompt_len = m.cfg.max_seq_len;
u32 *ids = (u32 *)malloc((m.cfg.max_seq_len + n_new) * sizeof(u32));
for (u32 i = 0; i < prompt_len; i++) {
u8 c = (u8)prompt[i];
ids[i] = c < m.cfg.vocab_size ? c : 32; /* fold non-ASCII to space */
}
/* Emit prompt then generate n_new tokens greedily. */
fwrite(prompt, 1, prompt_len, stdout);
u32 T = prompt_len;
for (u32 step = 0; step < n_new; step++) {
u32 next_id = argmax_next_token(&m, &b, ids, T);
putchar((int)next_id);
fflush(stdout);
if (T < m.cfg.max_seq_len) {
ids[T] = next_id;
T++;
} else {
/* slide window */
memmove(ids, ids + 1, (m.cfg.max_seq_len - 1) * sizeof(u32));
ids[m.cfg.max_seq_len - 1] = next_id;
}
}
putchar('\n');
return 0;
}