| |
| |
| |
| |
| |
| |
| |
| |
| |
| #include <assert.h> |
| #include <inttypes.h> |
| #include <stdint.h> |
| #include <stdio.h> |
| #include <stdlib.h> |
| #include <string.h> |
|
|
/* Short aliases for the fixed-width integer types used throughout the file. */
typedef uint8_t  u8;
typedef uint32_t u32;
typedef uint64_t u64;
typedef int32_t  i32;
typedef int64_t  i64;

/* Expected value of the first 32-bit header word of a weights file.
 * Stored little-endian, its bytes spell the ASCII string "BIT1". */
#define MAGIC_BIT1 0x31544942u
|
|
| |
| typedef struct { |
| u32 vocab_size; |
| u32 d_model; |
| u32 n_layers; |
| u32 n_heads; |
| u32 d_ff; |
| u32 max_seq_len; |
| i64 logit_scale_M; |
|
|
| u32 head_dim; |
| u32 words_d; |
| u32 words_ff; |
| u32 words_head; |
| } Config; |
|
|
| |
| typedef struct { |
| u64 *weight_bits; |
| i32 *threshold; |
| u32 in_features; |
| u32 out_features; |
| u32 words_in; |
| } BitLinear; |
|
|
| typedef struct { |
| i32 *alibi_slopes; |
| BitLinear q, k, v, o; |
| } Attention; |
|
|
| typedef struct { |
| BitLinear gate, up, down; |
| } FFN; |
|
|
| typedef struct { |
| Attention attn; |
| FFN ffn; |
| } Layer; |
|
|
| typedef struct { |
| Config cfg; |
| u64 *embed_bits; |
| Layer *layers; |
| u64 *out_codebook_bits; |
| i64 *int_out_bias; |
| } Model; |
|
|
| |
| typedef struct { |
| u64 *x; |
| u64 *q_all; |
| u64 *k_all; |
| u64 *v_all; |
| u64 *a_bits; |
| u64 *f_bits; |
| u64 *g_bits; |
| u64 *u_bits; |
| u64 *h_bits; |
| i32 *scores; |
| } Buffers; |
|
|
| |
/* Read exactly n bytes from f into ptr. If fewer bytes are available, print
 * a message naming `what` to stderr and terminate the process with status 1. */
static void must_read(void *ptr, size_t n, FILE *f, const char *what) {
    size_t got = fread(ptr, 1, n, f);
    if (got == n) {
        return;
    }
    fprintf(stderr, "short read on %s\n", what);
    exit(1);
}
|
|
| static void read_bitlinear(FILE *f, BitLinear *bl, u32 in_features, u32 out_features) { |
| bl->in_features = in_features; |
| bl->out_features = out_features; |
| bl->words_in = (in_features + 63) / 64; |
| size_t wb = (size_t)out_features * bl->words_in * sizeof(u64); |
| bl->weight_bits = (u64 *)malloc(wb); |
| bl->threshold = (i32 *)malloc(out_features * sizeof(i32)); |
| must_read(bl->weight_bits, wb, f, "BitLinear.weight_bits"); |
| must_read(bl->threshold, out_features * sizeof(i32), f, "BitLinear.threshold"); |
| } |
|
|
| static void load_model(const char *path, Model *m) { |
| FILE *f = fopen(path, "rb"); |
| if (!f) { perror(path); exit(1); } |
|
|
| u32 header[8]; |
| must_read(header, sizeof(header), f, "header"); |
| if (header[0] != MAGIC_BIT1) { |
| fprintf(stderr, "bad magic 0x%08x (want 0x%08x)\n", header[0], MAGIC_BIT1); |
| exit(1); |
| } |
| if (header[1] != 1) { fprintf(stderr, "bad version %u\n", header[1]); exit(1); } |
|
|
| Config *c = &m->cfg; |
| c->vocab_size = header[2]; |
| c->d_model = header[3]; |
| c->n_layers = header[4]; |
| c->n_heads = header[5]; |
| c->d_ff = header[6]; |
| c->max_seq_len = header[7]; |
| must_read(&c->logit_scale_M, sizeof(i64), f, "logit_scale_M"); |
|
|
| c->head_dim = c->d_model / c->n_heads; |
| c->words_d = (c->d_model + 63) / 64; |
| c->words_ff = (c->d_ff + 63) / 64; |
| c->words_head = (c->head_dim + 63) / 64; |
|
|
| |
| size_t eb = (size_t)c->vocab_size * c->words_d * sizeof(u64); |
| m->embed_bits = (u64 *)malloc(eb); |
| must_read(m->embed_bits, eb, f, "embedding"); |
|
|
| |
| m->layers = (Layer *)calloc(c->n_layers, sizeof(Layer)); |
| for (u32 l = 0; l < c->n_layers; l++) { |
| Layer *ly = &m->layers[l]; |
| ly->attn.alibi_slopes = (i32 *)malloc(c->n_heads * sizeof(i32)); |
| must_read(ly->attn.alibi_slopes, c->n_heads * sizeof(i32), f, "alibi_slopes"); |
| read_bitlinear(f, &ly->attn.q, c->d_model, c->d_model); |
| read_bitlinear(f, &ly->attn.k, c->d_model, c->d_model); |
| read_bitlinear(f, &ly->attn.v, c->d_model, c->d_model); |
| read_bitlinear(f, &ly->attn.o, c->d_model, c->d_model); |
| read_bitlinear(f, &ly->ffn.gate, c->d_model, c->d_ff); |
| read_bitlinear(f, &ly->ffn.up, c->d_model, c->d_ff); |
| read_bitlinear(f, &ly->ffn.down, c->d_ff, c->d_model); |
| } |
|
|
| |
| m->out_codebook_bits = (u64 *)malloc(eb); |
| must_read(m->out_codebook_bits, eb, f, "out_codebook"); |
| m->int_out_bias = (i64 *)malloc(c->vocab_size * sizeof(i64)); |
| must_read(m->int_out_bias, c->vocab_size * sizeof(i64), f, "int_out_bias"); |
|
|
| |
| u8 tail; |
| size_t got = fread(&tail, 1, 1, f); |
| if (got != 0) { |
| fprintf(stderr, "warning: extra bytes after expected EOF\n"); |
| } |
| fclose(f); |
| } |
|
|
| |
| static Buffers alloc_buffers(const Config *c) { |
| Buffers b = {0}; |
| size_t wd = c->words_d; |
| size_t wf = c->words_ff; |
| b.x = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64)); |
| b.q_all = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64)); |
| b.k_all = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64)); |
| b.v_all = (u64 *)calloc((size_t)c->max_seq_len * wd, sizeof(u64)); |
| b.a_bits = (u64 *)calloc(wd, sizeof(u64)); |
| b.f_bits = (u64 *)calloc(wd, sizeof(u64)); |
| b.g_bits = (u64 *)calloc(wf, sizeof(u64)); |
| b.u_bits = (u64 *)calloc(wf, sizeof(u64)); |
| b.h_bits = (u64 *)calloc(wf, sizeof(u64)); |
| b.scores = (i32 *)calloc((size_t)c->n_heads * c->max_seq_len, sizeof(i32)); |
| return b; |
| } |
|
|
| |
|
|
| |
| |
| |
| static inline i32 bipolar_dot(const u64 *a, const u64 *b, u32 words, u32 in_features) { |
| i64 agree = 0; |
| for (u32 w = 0; w < words; w++) { |
| u64 matches = ~(a[w] ^ b[w]); |
| agree += __builtin_popcountll(matches); |
| } |
| |
| |
| u32 pad = (words * 64) - in_features; |
| agree -= pad; |
| return (i32)(2 * agree - in_features); |
| } |
|
|
| |
| |
| |
| |
| static void bitlinear_forward(const BitLinear *bl, const u64 *x_bits, u64 *out_bits) { |
| u32 words_out = (bl->out_features + 63) / 64; |
| memset(out_bits, 0, words_out * sizeof(u64)); |
| for (u32 i = 0; i < bl->out_features; i++) { |
| const u64 *w_row = bl->weight_bits + (size_t)i * bl->words_in; |
| i32 y = bipolar_dot(w_row, x_bits, bl->words_in, bl->in_features); |
| if (y >= bl->threshold[i]) { |
| out_bits[i / 64] |= ((u64)1) << (i % 64); |
| } |
| } |
| } |
|
|
| |
| |
| |
| static inline void extract_head(const u64 *x_bits, u32 head_dim, u32 h, u64 *head_bits) { |
| u32 start_bit = h * head_dim; |
| u32 words = (head_dim + 63) / 64; |
| for (u32 w = 0; w < words; w++) { |
| |
| u32 bit_start = start_bit + w * 64; |
| u32 lo_word = bit_start / 64; |
| u32 shift = bit_start % 64; |
| u64 v = x_bits[lo_word] >> shift; |
| if (shift && (lo_word + 1) * 64 < start_bit + head_dim) { |
| v |= x_bits[lo_word + 1] << (64 - shift); |
| } |
| |
| u32 remaining = (w + 1) * 64 <= head_dim ? 64 : (head_dim - w * 64); |
| if (remaining < 64) { |
| v &= (remaining == 64) ? ~(u64)0 : (((u64)1 << remaining) - 1); |
| } |
| head_bits[w] = v; |
| } |
| } |
|
|
| |
| |
| static inline void majority3(const u64 *a, const u64 *b, const u64 *c, u64 *out, u32 words) { |
| for (u32 w = 0; w < words; w++) { |
| out[w] = (a[w] & b[w]) | (a[w] & c[w]) | (b[w] & c[w]); |
| } |
| } |
|
|
| |
| |
| |
| |
|
|
| static u32 argmax_next_token(const Model *m, Buffers *b, const u32 *ids, u32 T) { |
| const Config *c = &m->cfg; |
| u32 wd = c->words_d; |
| u32 wf = c->words_ff; |
|
|
| |
| for (u32 t = 0; t < T; t++) { |
| memcpy(b->x + (size_t)t * wd, |
| m->embed_bits + (size_t)ids[t] * wd, |
| wd * sizeof(u64)); |
| } |
|
|
| |
| u32 wh = c->words_head; |
| u64 q_head[8]; |
| u64 k_head[8]; |
| (void)q_head; (void)k_head; |
|
|
| |
| for (u32 li = 0; li < c->n_layers; li++) { |
| const Layer *ly = &m->layers[li]; |
|
|
| |
| for (u32 t = 0; t < T; t++) { |
| const u64 *xt = b->x + (size_t)t * wd; |
| bitlinear_forward(&ly->attn.q, xt, b->q_all + (size_t)t * wd); |
| bitlinear_forward(&ly->attn.k, xt, b->k_all + (size_t)t * wd); |
| bitlinear_forward(&ly->attn.v, xt, b->v_all + (size_t)t * wd); |
| } |
|
|
| |
| for (u32 t = 0; t < T; t++) { |
| const u64 *q_t = b->q_all + (size_t)t * wd; |
|
|
| |
| memset(b->a_bits, 0, wd * sizeof(u64)); |
|
|
| |
| for (u32 h = 0; h < c->n_heads; h++) { |
| |
| extract_head(q_t, c->head_dim, h, q_head); |
|
|
| |
| i32 best_score = INT32_MIN; |
| u32 best_j = 0; |
| for (u32 j = 0; j <= t; j++) { |
| const u64 *k_j = b->k_all + (size_t)j * wd; |
| extract_head(k_j, c->head_dim, h, k_head); |
| i32 s = bipolar_dot(q_head, k_head, wh, c->head_dim); |
| i32 dist = (i32)t - (i32)j; |
| if (dist < 0) dist = -dist; |
| s -= ly->attn.alibi_slopes[h] * dist; |
| if (s > best_score) { |
| best_score = s; |
| best_j = j; |
| } |
| } |
|
|
| |
| const u64 *v_bits = b->v_all + (size_t)best_j * wd; |
| for (u32 bit = 0; bit < c->head_dim; bit++) { |
| u32 src_bit = h * c->head_dim + bit; |
| u64 v = (v_bits[src_bit / 64] >> (src_bit % 64)) & 1ULL; |
| u32 dst_bit = h * c->head_dim + bit; |
| b->a_bits[dst_bit / 64] |= v << (dst_bit % 64); |
| } |
| } |
|
|
| |
| u64 a_tmp[16]; |
| bitlinear_forward(&ly->attn.o, b->a_bits, a_tmp); |
|
|
| |
| const u64 *x_t = b->x + (size_t)t * wd; |
| bitlinear_forward(&ly->ffn.gate, x_t, b->g_bits); |
| bitlinear_forward(&ly->ffn.up, x_t, b->u_bits); |
| for (u32 w = 0; w < wf; w++) { |
| b->h_bits[w] = ~(b->g_bits[w] ^ b->u_bits[w]); |
| } |
| bitlinear_forward(&ly->ffn.down, b->h_bits, b->f_bits); |
|
|
| |
| u64 new_x[16]; |
| majority3(x_t, a_tmp, b->f_bits, new_x, wd); |
| memcpy(b->x + (size_t)t * wd, new_x, wd * sizeof(u64)); |
| } |
| } |
|
|
| |
| const u64 *x_last = b->x + (size_t)(T - 1) * wd; |
| i64 best_logit = INT64_MIN; |
| u32 best_v = 0; |
| for (u32 v = 0; v < c->vocab_size; v++) { |
| const u64 *vec = m->out_codebook_bits + (size_t)v * wd; |
| i32 dot = bipolar_dot(vec, x_last, wd, c->d_model); |
| i64 logit = (i64)dot * c->logit_scale_M + m->int_out_bias[v]; |
| if (logit > best_logit) { |
| best_logit = logit; |
| best_v = v; |
| } |
| } |
| return best_v; |
| } |
|
|
| |
| int main(int argc, char **argv) { |
| if (argc < 4) { |
| fprintf(stderr, "usage: %s <weights.bin> \"<prompt>\" <num_new_tokens>\n", argv[0]); |
| return 2; |
| } |
| const char *bin_path = argv[1]; |
| const char *prompt = argv[2]; |
| u32 n_new = (u32)atoi(argv[3]); |
|
|
| Model m = {0}; |
| load_model(bin_path, &m); |
| fprintf(stderr, |
| "loaded v18 bin: vocab=%u d_model=%u n_layers=%u n_heads=%u d_ff=%u T_max=%u M=%" PRId64 "\n", |
| m.cfg.vocab_size, m.cfg.d_model, m.cfg.n_layers, m.cfg.n_heads, m.cfg.d_ff, |
| m.cfg.max_seq_len, m.cfg.logit_scale_M); |
|
|
| Buffers b = alloc_buffers(&m.cfg); |
|
|
| |
| u32 prompt_len = (u32)strlen(prompt); |
| if (prompt_len == 0) { fprintf(stderr, "empty prompt\n"); return 2; } |
| if (prompt_len > m.cfg.max_seq_len) prompt_len = m.cfg.max_seq_len; |
| u32 *ids = (u32 *)malloc((m.cfg.max_seq_len + n_new) * sizeof(u32)); |
| for (u32 i = 0; i < prompt_len; i++) { |
| u8 c = (u8)prompt[i]; |
| ids[i] = c < m.cfg.vocab_size ? c : 32; |
| } |
|
|
| |
| fwrite(prompt, 1, prompt_len, stdout); |
|
|
| u32 T = prompt_len; |
| for (u32 step = 0; step < n_new; step++) { |
| u32 next_id = argmax_next_token(&m, &b, ids, T); |
| putchar((int)next_id); |
| fflush(stdout); |
| if (T < m.cfg.max_seq_len) { |
| ids[T] = next_id; |
| T++; |
| } else { |
| |
| memmove(ids, ids + 1, (m.cfg.max_seq_len - 1) * sizeof(u32)); |
| ids[m.cfg.max_seq_len - 1] = next_id; |
| } |
| } |
| putchar('\n'); |
| return 0; |
| } |
|
|