/* Origin: bitnet-1bitllm / vm_backup / code / infer_omp.c
 * Uploaded by hidude562 — commit 4754707 (verified): "1bitllm code (checkpoints to follow)".
 */
/* Step 3: KV cache + AVX-512 SIMD + OpenMP.
*
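* Build: gcc -O3 -march=native -fopenmp -o infer_omp infer_omp.c
* (the build CPU must support AVX-512F, AVX-512VL and AVX-512 VPOPCNTDQ,
*  which the popcount kernels below require).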
*
* Parallelizes two independent loops:
* - BitLinear rows (each output neuron is an independent dot product).
* - Attention heads (independent across h).
*
* The rest of the forward step stays the same. Speedup should scale roughly
* linearly with core count until memory bandwidth saturates; the model is
* ~740 KB and fits in each core's L2, so expect 6-12x on a 16-thread (SMT) CPU.
*/
#include <assert.h>
#include <immintrin.h>
#include <inttypes.h>
#include <omp.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
typedef uint64_t u64;
typedef uint32_t u32;
typedef int32_t i32;
typedef int64_t i64;
typedef uint8_t u8;

/* File magic compared against header[0] by load_model: the bytes "BIT1"
 * read as a little-endian u32. */
#define MAGIC_BIT1 0x31544942u

/* Hyper-parameters from the model header, plus sizes derived from them. */
typedef struct {
    u32 vocab_size;
    u32 d_model;
    u32 n_layers;
    u32 n_heads;
    u32 d_ff;
    u32 max_seq_len;
    i64 logit_scale_M;   /* multiplies each output dot product before the bias is added */
    u32 head_dim;        /* d_model / n_heads */
    u32 words_d;         /* u64 words per d_model-bit vector */
    u32 words_ff;        /* u64 words per d_ff-bit vector */
    u32 words_head;      /* u64 words per head_dim-bit vector */
} Config;

/* Binary linear layer: output bit i is set when the bipolar dot product of
 * weight row i with the input reaches threshold[i]. */
typedef struct {
    u64 *weight_bits;    /* out_features rows of words_in u64 words each */
    i32 *threshold;      /* one entry per output */
    u32 in_features;
    u32 out_features;
    u32 words_in;        /* (in_features + 63) / 64 */
} BitLinear;

/* Attention parameters: per-head ALiBi distance-penalty slopes plus the
 * Q/K/V/O projections. */
typedef struct {
    i32 *alibi_slopes;   /* n_heads entries */
    BitLinear q;
    BitLinear k;
    BitLinear v;
    BitLinear o;
} Attention;

/* Gated feed-forward: gate and up map d_model -> d_ff, down maps d_ff -> d_model. */
typedef struct {
    BitLinear gate;
    BitLinear up;
    BitLinear down;
} FFN;

/* One transformer layer and its KV cache. */
typedef struct {
    Attention attn;
    FFN ffn;
    u64 *k_cache;        /* max_seq_len rows of words_d words, one row per position */
    u64 *v_cache;        /* same layout as k_cache */
} Layer;

/* A fully loaded model. */
typedef struct {
    Config cfg;
    u64 *embed_bits;         /* vocab_size rows of words_d words */
    Layer *layers;           /* n_layers entries */
    u64 *out_codebook_bits;  /* vocab_size rows of words_d words, scored against the final state */
    i64 *int_out_bias;       /* per-vocab bias added to the scaled output dot product */
} Model;
static inline i32 bipolar_dot_d256(const u64 *a, const u64 *b) {
__m256i va = _mm256_loadu_si256((const __m256i *)a);
__m256i vb = _mm256_loadu_si256((const __m256i *)b);
__m256i vxor = _mm256_xor_si256(va, vb);
__m256i vpop = _mm256_popcnt_epi64(vxor);
u64 d0 = _mm256_extract_epi64(vpop, 0);
u64 d1 = _mm256_extract_epi64(vpop, 1);
u64 d2 = _mm256_extract_epi64(vpop, 2);
u64 d3 = _mm256_extract_epi64(vpop, 3);
return (i32)(256 - 2 * (i64)(d0 + d1 + d2 + d3));
}
static inline i32 bipolar_dot_d512(const u64 *a, const u64 *b) {
__m512i va = _mm512_loadu_si512((const void *)a);
__m512i vb = _mm512_loadu_si512((const void *)b);
return (i32)(512 - 2 * (i64)_mm512_reduce_add_epi64(_mm512_popcnt_epi64(_mm512_xor_si512(va, vb))));
}
/* Bipolar (±1) dot product over the low 32 bits of a and b; the upper
 * 32 bits of both arguments are ignored. Result lies in [-32, 32]. */
static inline i32 bipolar_dot_h32(u64 a, u64 b) {
    const u64 low_diff = (a ^ b) & 0xFFFFFFFFULL;
    const i32 mismatches = (i32)__builtin_popcountll(low_diff);
    return 32 - mismatches * 2;
}
static inline i32 bipolar_dot_h64(const u64 *a, const u64 *b) {
return 64 - 2 * (i32)__builtin_popcountll(a[0] ^ b[0]);
}
/* Bipolar (±1) dot product of two bit vectors of any length:
 * in_features - 2 * popcount(a ^ b) over `words` u64 words.
 * Every bit of every word is counted, including bits past in_features in the
 * last word, so callers must keep those padding bits identical in a and b.
 * Work is done in 8-word ZMM blocks (VPOPCNTQ), then at most one 4-word YMM
 * block, then a scalar tail of at most 3 words. */
static inline i32 bipolar_dot_avx512_gen(const u64 *a, const u64 *b, u32 words, u32 in_features) {
    u32 pos = 0;
    __m512i zmm_sum = _mm512_setzero_si512();
    while (pos + 8 <= words) {
        const __m512i diff = _mm512_xor_si512(_mm512_loadu_si512((const void *)(a + pos)),
                                              _mm512_loadu_si512((const void *)(b + pos)));
        zmm_sum = _mm512_add_epi64(zmm_sum, _mm512_popcnt_epi64(diff));
        pos += 8;
    }
    i64 mismatches = _mm512_reduce_add_epi64(zmm_sum);

    /* At most one 4-word YMM block remains worth vectorizing. */
    if (pos + 4 <= words) {
        const __m256i diff = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)(a + pos)),
                                              _mm256_loadu_si256((const __m256i *)(b + pos)));
        u64 lane_counts[4];
        _mm256_storeu_si256((__m256i *)lane_counts, _mm256_popcnt_epi64(diff));
        for (int lane = 0; lane < 4; lane++) mismatches += (i64)lane_counts[lane];
        pos += 4;
    }

    /* Scalar tail. */
    while (pos < words) {
        mismatches += __builtin_popcountll(a[pos] ^ b[pos]);
        pos++;
    }
    return (i32)(in_features - 2 * mismatches);
}
/* Portable scalar bipolar (±1) dot product: in_features - 2 * popcount(a ^ b)
 * summed over `words` u64 words. Nothing in this file calls it; it is a
 * reference implementation for the SIMD kernels. */
static inline i32 bipolar_dot_gen(const u64 *a, const u64 *b, u32 words, u32 in_features) {
    i64 mismatches = 0;
    u32 i = words;
    while (i--) mismatches += __builtin_popcountll(a[i] ^ b[i]);
    return (i32)(in_features - 2 * mismatches);
}
/* Per-thread output scratch capacity in u64 words (= 4096 output bits). This
 * matches step_token's largest activation buffer (d_ff <= 4096 bits = 64 u64). */
enum { BITLINEAR_MAX_OUT_WORDS = 64 };

/* Binary linear layer: output bit i is set iff the bipolar (±1) dot product of
 * weight row i with x_bits is >= threshold[i].
 *
 * x_bits must provide bl->words_in words. The popcount kernels count padding
 * bits past in_features, so x's padding must match the weight rows' padding
 * (presumably zero — confirm against the exporter). out_bits receives
 * ceil(out_features / 64) words, and bits past out_features are cleared.
 *
 * Rows are split statically across OpenMP threads. Each thread collects its
 * bits in a private buffer and ORs them into out_bits atomically, because rows
 * owned by different threads can share a u64 word.
 *
 * Exits with an error if out_features exceeds BITLINEAR_MAX_OUT_WORDS * 64
 * (the per-thread buffer would otherwise overflow). */
static void bitlinear_forward(const BitLinear *bl, const u64 *x_bits, u64 *out_bits) {
    const u32 N = bl->out_features;
    const u32 in = bl->in_features;
    const u32 words_out = (N + 63) / 64;
    if (words_out > BITLINEAR_MAX_OUT_WORDS) {
        fprintf(stderr, "bitlinear_forward: out_features=%u exceeds supported maximum %u\n",
                N, (u32)BITLINEAR_MAX_OUT_WORDS * 64u);
        exit(1);
    }
    memset(out_bits, 0, words_out * sizeof(u64));

    #pragma omp parallel
    {
        u64 local_out[BITLINEAR_MAX_OUT_WORDS] = {0};
        /* `in` is shared, so every thread takes the same branch and meets the
         * same worksharing construct. Dispatching outside the row loop keeps
         * the fixed-width kernels specialized per path. */
        if (in == 256) {
            #pragma omp for nowait schedule(static)
            for (u32 i = 0; i < N; i++) {
                const i32 y = bipolar_dot_d256(bl->weight_bits + (size_t)i * 4, x_bits);
                if (y >= bl->threshold[i]) local_out[i / 64] |= ((u64)1) << (i % 64);
            }
        } else if (in == 512) {
            #pragma omp for nowait schedule(static)
            for (u32 i = 0; i < N; i++) {
                const i32 y = bipolar_dot_d512(bl->weight_bits + (size_t)i * 8, x_bits);
                if (y >= bl->threshold[i]) local_out[i / 64] |= ((u64)1) << (i % 64);
            }
        } else {
            /* Any other width (d_model=768, d_ff=1280, etc.): generic AVX-512 kernel. */
            const u32 words_in = bl->words_in;
            #pragma omp for nowait schedule(static)
            for (u32 i = 0; i < N; i++) {
                const i32 y = bipolar_dot_avx512_gen(bl->weight_bits + (size_t)i * words_in,
                                                     x_bits, words_in, in);
                if (y >= bl->threshold[i]) local_out[i / 64] |= ((u64)1) << (i % 64);
            }
        }
        for (u32 w = 0; w < words_out; w++) {
            if (local_out[w] == 0) continue;  /* nothing to merge; skip the atomic */
            #pragma omp atomic
            out_bits[w] |= local_out[w];
        }
    }
}
/* Copies head h's slice of x_bits, i.e. bits [h*head_dim, (h+1)*head_dim),
 * into head_bits starting at bit 0. Writes ceil(head_dim / 64) words and
 * clears the bits past head_dim in the last one. Reads only the x_bits words
 * that the slice overlaps. */
static inline void extract_head(const u64 *x_bits, u32 head_dim, u32 h, u64 *head_bits) {
    const u32 first_bit = h * head_dim;
    const u32 end_bit = first_bit + head_dim;   /* one past the slice */
    const u32 n_out = (head_dim + 63) / 64;
    for (u32 i = 0; i < n_out; i++) {
        const u32 src = first_bit + i * 64;
        const u32 word = src / 64;
        const u32 offset = src % 64;
        u64 chunk = x_bits[word] >> offset;
        /* A misaligned chunk borrows its high bits from the next word, but
         * only when the slice actually extends into that word. */
        const int spills = offset != 0 && (word + 1) * 64 < end_bit;
        if (spills) chunk |= x_bits[word + 1] << (64 - offset);
        const u32 left = head_dim - i * 64;   /* > 0 because i < n_out */
        if (left < 64) chunk &= ((u64)1 << left) - 1;
        head_bits[i] = chunk;
    }
}
/* Bitwise majority vote of three bit vectors of `words` u64 words: each
 * output bit is set when at least two of a, b, c have it set. */
static inline void majority3(const u64 *a, const u64 *b, const u64 *c, u64 *out, u32 words) {
    for (u32 i = 0; i < words; ++i) {
        const u64 x = a[i], y = b[i], z = c[i];
        out[i] = (x & y) | (z & (x | y));
    }
}
/* Reads exactly n bytes from f into ptr. On a short read (EOF or error), it
 * prints "short read <what>" to stderr and terminates with exit status 1. */
static void must_read(void *ptr, size_t n, FILE *f, const char *what) {
    const size_t got = fread(ptr, 1, n, f);
    if (got == n) {
        return;
    }
    fprintf(stderr, "short read %s\n", what);
    exit(1);
}
/* Reads one BitLinear layer from f into bl. The layer is stored as out_features
 * weight rows of ceil(in_features / 64) u64 words each, followed by
 * out_features i32 thresholds. Weight storage is 64-byte aligned, with its size
 * rounded up to a multiple of 64 as aligned_alloc requires. Exits with a
 * message on allocation failure or short read. */
static void read_bitlinear(FILE *f, BitLinear *bl, u32 in_features, u32 out_features) {
    bl->in_features = in_features;
    bl->out_features = out_features;
    bl->words_in = (in_features + 63) / 64;
    const size_t wb = (size_t)out_features * bl->words_in * sizeof(u64);
    const size_t tb = (size_t)out_features * sizeof(i32);
    /* Mask with a size_t constant: ~63UL would truncate the size where long is 32 bits. */
    bl->weight_bits = (u64 *)aligned_alloc(64, (wb + 63) & ~(size_t)63);
    bl->threshold = (i32 *)malloc(tb);
    /* Zero-byte requests may legitimately return NULL; only flag real failures. */
    if ((wb != 0 && bl->weight_bits == NULL) || (tb != 0 && bl->threshold == NULL)) {
        fprintf(stderr, "out of memory reading BitLinear %ux%u\n", in_features, out_features);
        exit(1);
    }
    must_read(bl->weight_bits, wb, f, "w");
    must_read(bl->threshold, tb, f, "thr");
}
/* Exits with a message if an allocation of n > 0 bytes returned NULL; otherwise returns p. */
static void *load_check_alloc(void *p, size_t n, const char *what) {
    if (p == NULL && n != 0) {
        fprintf(stderr, "out of memory: %s (%zu bytes)\n", what, n);
        exit(1);
    }
    return p;
}

/* Loads a BIT1 model file into m. Reads, in order:
 *   u32 header[8]: magic, (word 1, not read here), vocab_size, d_model,
 *                  n_layers, n_heads, d_ff, max_seq_len
 *   i64 logit_scale_M
 *   embeddings: vocab_size rows of words_d u64
 *   per layer: n_heads i32 ALiBi slopes, then BitLinear q, k, v, o
 *              (d_model -> d_model), gate, up (d_model -> d_ff), down (d_ff -> d_model)
 *   output codebook: vocab_size rows of words_d u64
 *   vocab_size i64 output biases
 * Also zero-allocates each layer's K/V cache (max_seq_len rows of words_d words).
 * Exits with a message on I/O, format, size or allocation errors. */
static void load_model(const char *path, Model *m) {
    FILE *f = fopen(path, "rb");
    if (!f) { perror(path); exit(1); }
    u32 header[8];
    must_read(header, sizeof(header), f, "hdr");
    if (header[0] != MAGIC_BIT1) { fprintf(stderr, "bad magic\n"); exit(1); }
    Config *c = &m->cfg;
    c->vocab_size = header[2];
    c->d_model = header[3];
    c->n_layers = header[4];
    c->n_heads = header[5];
    c->d_ff = header[6];
    c->max_seq_len = header[7];
    must_read(&c->logit_scale_M, sizeof(i64), f, "M");

    /* n_heads == 0 would divide by zero below; n_heads > d_model gives empty heads. */
    if (c->n_heads == 0 || c->n_heads > c->d_model || c->vocab_size == 0 || c->max_seq_len == 0) {
        fprintf(stderr, "bad config: vocab=%u d=%u H=%u Tmax=%u\n",
                c->vocab_size, c->d_model, c->n_heads, c->max_seq_len);
        exit(1);
    }
    c->head_dim = c->d_model / c->n_heads;
    /* step_token keeps activations in fixed-size stack buffers sized for
     * d_model <= 2048, d_ff <= 4096 and head_dim <= 512 bits; reject larger
     * models here rather than overflow those buffers at inference time. */
    if (c->d_model > 2048 || c->d_ff > 4096 || c->head_dim > 512) {
        fprintf(stderr, "model too large: d=%u ff=%u head_dim=%u (limits 2048/4096/512)\n",
                c->d_model, c->d_ff, c->head_dim);
        exit(1);
    }
    c->words_d = (c->d_model + 63) / 64;
    c->words_ff = (c->d_ff + 63) / 64;
    c->words_head = (c->head_dim + 63) / 64;

    const size_t eb = (size_t)c->vocab_size * c->words_d * sizeof(u64);
    /* aligned_alloc needs a size that is a multiple of the alignment; mask with
     * a size_t constant so the rounding is not truncated where long is 32 bits. */
    const size_t eb_alloc = (eb + 63) & ~(size_t)63;
    m->embed_bits = load_check_alloc(aligned_alloc(64, eb_alloc), eb_alloc, "embed");
    must_read(m->embed_bits, eb, f, "embed");

    m->layers = load_check_alloc(calloc(c->n_layers, sizeof(Layer)),
                                 (size_t)c->n_layers * sizeof(Layer), "layers");
    const size_t slopes_bytes = (size_t)c->n_heads * sizeof(i32);
    const size_t cache_words = (size_t)c->max_seq_len * c->words_d;
    for (u32 l = 0; l < c->n_layers; l++) {
        Layer *ly = &m->layers[l];
        ly->attn.alibi_slopes = load_check_alloc(malloc(slopes_bytes), slopes_bytes, "alibi");
        must_read(ly->attn.alibi_slopes, slopes_bytes, f, "alibi");
        read_bitlinear(f, &ly->attn.q, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.k, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.v, c->d_model, c->d_model);
        read_bitlinear(f, &ly->attn.o, c->d_model, c->d_model);
        read_bitlinear(f, &ly->ffn.gate, c->d_model, c->d_ff);
        read_bitlinear(f, &ly->ffn.up, c->d_model, c->d_ff);
        read_bitlinear(f, &ly->ffn.down, c->d_ff, c->d_model);
        ly->k_cache = load_check_alloc(calloc(cache_words, sizeof(u64)),
                                       cache_words * sizeof(u64), "k_cache");
        ly->v_cache = load_check_alloc(calloc(cache_words, sizeof(u64)),
                                       cache_words * sizeof(u64), "v_cache");
    }

    m->out_codebook_bits = load_check_alloc(aligned_alloc(64, eb_alloc), eb_alloc, "codebook");
    must_read(m->out_codebook_bits, eb, f, "cb");
    const size_t bias_bytes = (size_t)c->vocab_size * sizeof(i64);
    m->int_out_bias = load_check_alloc(malloc(bias_bytes), bias_bytes, "bias");
    must_read(m->int_out_bias, bias_bytes, f, "bias");
    fclose(f);
}
static u32 step_token(Model *m, u32 token_id, u32 t) {
const Config *c = &m->cfg;
u32 wd = c->words_d, wf = c->words_ff;
/* Stack allocation: SHARED by default inside enclosed omp parallel regions
* (OpenMP rule for automatic vars declared outside the region), which is
* what we need. __thread would give each worker its own copy and drop
* writes on return to main. */
/* Sized for d_model ≤ 2048 (=32 u64) and d_ff ≤ 4096 (=64 u64) */
u64 x[32] __attribute__((aligned(64)));
u64 q[32], k[32], v[32], a_bits[32], o_bits[32];
u64 g_bits[64], u_bits[64], h_bits[64], f_bits[32], new_x[32];
u64 q_head[8];
memcpy(x, m->embed_bits + (size_t)token_id * wd, wd * sizeof(u64));
for (u32 li = 0; li < c->n_layers; li++) {
Layer *ly = &m->layers[li];
bitlinear_forward(&ly->attn.q, x, q);
bitlinear_forward(&ly->attn.k, x, k);
bitlinear_forward(&ly->attn.v, x, v);
memcpy(ly->k_cache + (size_t)t * wd, k, wd * sizeof(u64));
memcpy(ly->v_cache + (size_t)t * wd, v, wd * sizeof(u64));
memset(a_bits, 0, wd * sizeof(u64));
/* Parallelize across heads. */
#pragma omp parallel for schedule(static)
for (u32 h = 0; h < c->n_heads; h++) {
u64 q_h[8]; u64 k_h[8];
extract_head(q, c->head_dim, h, q_h);
i32 best_score = INT32_MIN;
u32 best_j = 0;
if (c->head_dim == 32) {
for (u32 j = 0; j <= t; j++) {
const u64 *k_j = ly->k_cache + (size_t)j * wd;
u64 k_word = k_j[(h * c->head_dim) / 64] >> ((h * c->head_dim) % 64);
k_word &= 0xFFFFFFFFULL;
i32 s = bipolar_dot_h32(q_h[0], k_word);
i32 d = (i32)t - (i32)j; if (d < 0) d = -d;
s -= ly->attn.alibi_slopes[h] * d;
if (s > best_score) { best_score = s; best_j = j; }
}
} else {
for (u32 j = 0; j <= t; j++) {
const u64 *k_j = ly->k_cache + (size_t)j * wd;
extract_head(k_j, c->head_dim, h, k_h);
i32 s = bipolar_dot_avx512_gen(q_h, k_h, c->words_head, c->head_dim);
i32 d = (i32)t - (i32)j; if (d < 0) d = -d;
s -= ly->attn.alibi_slopes[h] * d;
if (s > best_score) { best_score = s; best_j = j; }
}
}
/* Gather V slice for this head */
const u64 *v_bits = ly->v_cache + (size_t)best_j * wd;
u64 head_acc[8] = {0};
for (u32 bit = 0; bit < c->head_dim; bit++) {
u32 src_bit = h * c->head_dim + bit;
u64 vv = (v_bits[src_bit / 64] >> (src_bit % 64)) & 1ULL;
head_acc[(h * c->head_dim + bit) / 64 - (h * c->head_dim) / 64] |= vv << ((h * c->head_dim + bit) % 64);
}
/* Merge this head's bits back into a_bits (OR, thread-safe for non-overlapping heads on 64-bit boundary) */
u32 base_word = (h * c->head_dim) / 64;
u32 end_word = (((h + 1) * c->head_dim) + 63) / 64;
/* Heads at 32-bit widths may share words with neighbors; use atomic OR. */
for (u32 w = base_word; w < end_word; w++) {
#pragma omp atomic
a_bits[w] |= head_acc[w - base_word];
}
}
bitlinear_forward(&ly->attn.o, a_bits, o_bits);
bitlinear_forward(&ly->ffn.gate, x, g_bits);
bitlinear_forward(&ly->ffn.up, x, u_bits);
for (u32 w = 0; w < wf; w++) h_bits[w] = ~(g_bits[w] ^ u_bits[w]);
bitlinear_forward(&ly->ffn.down, h_bits, f_bits);
majority3(x, o_bits, f_bits, new_x, wd);
memcpy(x, new_x, wd * sizeof(u64));
}
i64 best_logit = INT64_MIN;
u32 best_v = 0;
for (u32 vid = 0; vid < c->vocab_size; vid++) {
const u64 *vec = m->out_codebook_bits + (size_t)vid * wd;
i32 dot = (c->d_model == 256) ? bipolar_dot_d256(vec, x)
: bipolar_dot_avx512_gen(vec, x, wd, c->d_model);
i64 logit = (i64)dot * c->logit_scale_M + m->int_out_bias[vid];
if (logit > best_logit) { best_logit = logit; best_v = vid; }
}
return best_v;
}
/* Current CLOCK_MONOTONIC time in milliseconds (arbitrary epoch). main uses
 * it only to measure elapsed intervals. */
static double now_ms(void) {
    struct timespec now;
    clock_gettime(CLOCK_MONOTONIC, &now);
    const double whole_ms = now.tv_sec * 1000.0;
    const double frac_ms = now.tv_nsec * 1e-6;
    return whole_ms + frac_ms;
}
int main(int argc, char **argv) {
if (argc < 4) { fprintf(stderr, "usage: %s <bin> \"<prompt>\" <n_new> [threads]\n", argv[0]); return 2; }
const char *bin = argv[1];
const char *prompt = argv[2];
u32 n_new = (u32)atoi(argv[3]);
int threads = argc > 4 ? atoi(argv[4]) : 0;
if (threads > 0) omp_set_num_threads(threads);
Model m = {0}; load_model(bin, &m);
fprintf(stderr,
"loaded: vocab=%u d=%u L=%u H=%u ff=%u Tmax=%u threads=%d\n",
m.cfg.vocab_size, m.cfg.d_model, m.cfg.n_layers, m.cfg.n_heads, m.cfg.d_ff,
m.cfg.max_seq_len, omp_get_max_threads());
u32 prompt_len = (u32)strlen(prompt);
if (prompt_len == 0 || prompt_len + n_new > m.cfg.max_seq_len) return 2;
double t0 = now_ms();
u32 next_id = 0;
for (u32 t = 0; t < prompt_len; t++) next_id = step_token(&m, (u32)(u8)prompt[t], t);
double t1 = now_ms();
fwrite(prompt, 1, prompt_len, stdout);
double t2 = now_ms();
for (u32 s = 0; s < n_new; s++) {
putchar((int)next_id); fflush(stdout);
u32 pos = prompt_len + s;
if (pos >= m.cfg.max_seq_len) break;
next_id = step_token(&m, next_id, pos);
}
double t3 = now_ms();
putchar('\n');
fprintf(stderr,
"prefill: %u tok in %.2f ms (%.0f tok/s)\n"
"generate: %u tok in %.2f ms (%.0f tok/s)\n",
prompt_len, t1 - t0, prompt_len * 1000.0 / (t1 - t0),
n_new, t3 - t2, n_new * 1000.0 / (t3 - t2));
return 0;
}