unary-quantization-research / pure_unary_engine.c
OpenTransformer's picture
Add files using upload-large-folder tool
19ed98b verified
/*
* PURE UNARY TRANSFORMER ENGINE
*
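* All linear projections inside the transformer layers (Q/K/V/O and
* gate/up/down) use base-1 (unary) arithmetic:
* - Weights: unary encoded (sign + N magnitude planes)
* - Activations: unary encoded (sign + M magnitude planes)
* - Matmul = bitwise AND + popcount across plane pairs
* - Float is only used for: RMSNorm, SiLU, Softmax, RoPE, rescale, residual
*   add, and the attention score / value products over the KV cache
* - These are O(dim) (or O(seq × head_dim)), not O(dim²), so they don't dominate
* - The embedding lookup and the (tied) LM head remain FP16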
*
* (c) 2026 OpenTransformers Ltd / Scott Bisset
*/
#include <immintrin.h>
#include <omp.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdio.h>
#include <time.h>
/* Maximum context length: number of positions held per layer by the KV cache
 * and the capacity of the attention-score scratch buffer. Positions >= MAX_SEQ
 * must not be fed to the forward pass. */
#define MAX_SEQ 4096
/* Added to the mean square inside RMSNorm so the reciprocal sqrt never divides by zero. */
#define RMS_EPS 1e-6f
/* ============================================================
* Unary vector: a quantized 1D activation or intermediate
* ============================================================ */
/* Layout as produced by quantize_to_unary(): element i lives in word i / 64,
 * bit i % 64. A set sign bit marks a negative element. An element of
 * magnitude level k has its bit set in planes 0..k-1 (thermometer code), so
 * its dequantized value is (+/-k) * scale. */
typedef struct {
uint64_t *sign; /* [chunks] - one bit per element, set = negative */
uint64_t *planes; /* [n_planes][chunks] - magnitude planes, plane-major */
float scale; /* dequantization step: absmax / n_planes */
int dim; /* number of logical elements */
int chunks; /* (dim + 63) / 64 words per plane */
int n_planes; /* number of magnitude levels */
} UnaryVec;
/* ============================================================
* Config
* ============================================================ */
/* Model hyperparameters; filled in by model_alloc(). */
typedef struct {
int hidden; /* residual-stream width */
int inter; /* MLP intermediate width (gate/up outputs, down_proj input) */
int n_heads; /* query heads */
int n_kv_heads; /* key/value heads; each serves n_heads / n_kv_heads query heads */
int head_dim; /* per-head dimension for Q, K and V */
int n_layers;
int vocab; /* vocabulary size: rows of the FP16 embedding / tied LM head */
float rope_theta; /* RoPE base, passed to apply_rope() */
int tie_embeddings; /* nonzero: logits are computed with the FP16 embedding matrix */
int w_planes; /* weight quantization planes (reported only; each UnaryWeight carries its own n_planes) */
int a_planes; /* activation quantization planes */
} Config;
/* Unary weight matrix of shape out_dim x in_dim. The buffers are supplied by
 * the caller through init_unary_weight(); this file only stores the pointers.
 * Indexing used by pure_unary_matvec():
 *   sign_bits  [out_dim][chunks]           - sign bit per weight (only XORed with
 *                                            the activation sign; presumably
 *                                            set = negative, matching activations)
 *   mag_planes [n_planes][out_dim][chunks] - thermometer magnitude planes
 *   scales     [out_dim]                   - per-row dequantization scale */
typedef struct {
uint64_t *sign_bits;
uint64_t *mag_planes;
float *scales;
int out_dim;
int in_dim;
int n_planes;
int chunks; /* = (in_dim + 63) / 64 */
} UnaryWeight;
/* Transformer layer: seven unary projections plus float norm weights.
 * The norm pointers are borrowed from the caller (layer_set_norms /
 * layer_set_qk_norm). q_norm / k_norm may be NULL, in which case attention()
 * skips the per-head QK-norm. */
typedef struct {
UnaryWeight q_proj, k_proj, v_proj, o_proj;
UnaryWeight gate_proj, up_proj, down_proj;
float *input_norm; /* [hidden] pre-attention RMSNorm weight */
float *post_norm; /* [hidden] pre-MLP RMSNorm weight */
float *q_norm, *k_norm; /* [head_dim] each, shared by all heads; optional */
} Layer;
/* Full model: config, borrowed weights, KV cache and all per-token scratch.
 * Created by model_alloc(); weights are installed with the setter functions. */
typedef struct {
Config cfg;
uint16_t *embed; /* FP16 [vocab][hidden], borrowed (model_set_embed); also the tied LM head */
Layer *layers; /* [n_layers] */
float *final_norm; /* [hidden], owned copy filled by model_set_final_norm */
/* KV cache (float - only O(seq × heads × dim) not O(dim²)).
 * Layout [n_layers][MAX_SEQ][n_kv_heads][head_dim]; addressed via kv_ptr(). */
float *k_cache;
float *v_cache;
/* Scratch - float buffers for non-matmul ops */
float *hidden; /* residual stream */
float *normed; /* after RMSNorm, before quantization; the MLP also writes its down_proj output here */
float *q_float; /* [n_heads * head_dim] */
float *k_float; /* [n_kv_heads * head_dim] */
float *v_float; /* [n_kv_heads * head_dim] */
float *attn_out; /* concatenated head outputs, then the attention block's hidden-sized result */
float *gate_float; /* [inter] */
float *up_float; /* [inter] */
float *mlp_act; /* gate*up result before quantization */
float *logits; /* [vocab] */
float *attn_scores; /* [MAX_SEQ] softmax scratch for one head */
/* Scratch - unary vectors for matmul inputs */
UnaryVec uv_normed; /* Q/K/V input */
UnaryVec uv_mlp_in; /* gate/up input */
UnaryVec uv_mlp_act; /* for down_proj input */
/* Output integer accumulators (avoid malloc per call).
 * NOTE(review): passed to pure_unary_matvec, which currently does not use it. */
int *acc_buf;
} Model;
/* ============================================================
* ACTIVATION QUANTIZATION: float -> unary
* Runs per-vector: one scale for the entire vector.
* O(dim × n_planes) bit writes - small next to the O(dim²) matvecs.
* ============================================================ */
/*
 * Quantize x[0..dim) to symmetric unary (sign + thermometer) form.
 *
 * scale = absmax(x) / n_planes. An all-zero vector uses absmax 1.
 * Each element gets level k = round(|x| / scale), clamped to [0, n_planes].
 * Its bit is set in planes 0..k-1, and in sign_out when x < 0.
 *
 * sign_out   : (dim + 63) / 64 words, fully overwritten.
 * planes_out : n_planes * ((dim + 63) / 64) words, plane-major, fully overwritten.
 * scale_out  : receives the dequantization step.
 *
 * Elements whose level is not a finite number (NaN input, or Inf combined with
 * an infinite absmax) are encoded with magnitude 0. They no longer go through
 * an undefined float->int conversion.
 */
static void quantize_to_unary(
    const float *x, int dim, int n_planes,
    uint64_t *sign_out, uint64_t *planes_out, float *scale_out
) {
    int chunks = (dim + 63) / 64;

    /* Find absmax */
    float amax = 0.0f;
    for (int i = 0; i < dim; i++) {
        float a = fabsf(x[i]);
        if (a > amax) amax = a;
    }
    if (amax == 0.0f) amax = 1.0f;
    *scale_out = amax / n_planes;

    memset(sign_out, 0, (size_t)chunks * sizeof(uint64_t));
    memset(planes_out, 0, (size_t)n_planes * chunks * sizeof(uint64_t));

    float inv_scale = n_planes / amax;
    for (int i = 0; i < dim; i++) {
        int chunk = i / 64;
        uint64_t mask = 1ULL << (i % 64);

        if (x[i] < 0.0f)
            sign_out[chunk] |= mask;

        /* Round-half-up level. Clamp in float space BEFORE converting:
         * converting NaN, Inf or an out-of-range float to int is undefined
         * behaviour. The negated comparison also sends NaN to level 0. */
        float level = fabsf(x[i]) * inv_scale + 0.5f;
        int mag;
        if (!(level >= 1.0f))
            mag = 0;
        else if (level >= (float)n_planes)
            mag = n_planes;
        else
            mag = (int)level;

        /* Thermometer encode: set this element's bit in planes 0..mag-1 */
        for (int p = 0; p < mag; p++)
            planes_out[(size_t)p * chunks + chunk] |= mask;
    }
}
/* ============================================================
* PURE UNARY MATVEC: y = W @ x
*
* Both W and x are unary encoded; the inner loop is only AND + popcount.
* A float multiply happens once per output element (the rescale).
* ============================================================ */
/*
 * y_out[r] = W->scales[r] * x_scale * sum_j s_rj * |w_rj| * |x_j|
 *
 * With thermometer codes, |w_rj| * |x_j| equals the number of (weight plane,
 * input plane) pairs in which bit j is set in both. The integer dot product is
 * therefore a sum of popcounts: +1 where the sign bits agree, -1 where they
 * differ.
 *
 * x_sign / x_planes must use W->chunks words per plane, with the input
 * planes stored plane-major. acc_buf is accepted for interface compatibility
 * but is not used; each row accumulates into a local 64-bit counter.
 * Rows are distributed across OpenMP threads.
 */
static void pure_unary_matvec(
    const UnaryWeight *W,
    const uint64_t *x_sign, const uint64_t *x_planes,
    float x_scale, int x_n_planes,
    float *y_out,
    int *acc_buf
) {
    const int rows = W->out_dim;
    const int words = W->chunks;
    const int n_wplanes = W->n_planes;

    #pragma omp parallel for schedule(dynamic, 32)
    for (int row = 0; row < rows; row++) {
        const uint64_t *row_sign = W->sign_bits + (size_t)row * words;
        long long total = 0;

        for (int word = 0; word < words; word++) {
            const uint64_t differ = row_sign[word] ^ x_sign[word];
            const uint64_t agree = ~differ;

            for (int wp = 0; wp < n_wplanes; wp++) {
                const uint64_t wbits =
                    W->mag_planes[((size_t)wp * rows + row) * words + word];
                /* Split this weight plane by sign agreement once, instead of
                 * re-masking for every activation plane. */
                const uint64_t w_agree = wbits & agree;
                const uint64_t w_differ = wbits & differ;

                for (int xq = 0; xq < x_n_planes; xq++) {
                    const uint64_t xbits = x_planes[(size_t)xq * words + word];
                    total += __builtin_popcountll(w_agree & xbits)
                           - __builtin_popcountll(w_differ & xbits);
                }
            }
        }

        /* Single float rescale per output element */
        y_out[row] = (float)total * W->scales[row] * x_scale;
    }
    (void)acc_buf;
}
/* ============================================================
* FP16 embedding lookup (only used for embed/lm_head)
* ============================================================ */
static void embed_token(const uint16_t *embed, int token_id, float *out, int hidden) {
const uint16_t *row = embed + (size_t)token_id * hidden;
int i;
for (i = 0; i + 16 <= hidden; i += 16) {
__m256i h = _mm256_loadu_si256((__m256i*)(row + i));
__m512 fv = _mm512_cvtph_ps(h);
_mm512_storeu_ps(out + i, fv);
}
for (; i < hidden; i++) {
__m128i hv = _mm_set1_epi16(row[i]);
__m128 fv = _mm_cvtph_ps(hv);
_mm_store_ss(out + i, fv);
}
}
/* FP16 matvec for lm_head (vocab is huge, keep as FP16) */
static void fp16_matvec(const uint16_t *w, const float *x, float *y, int out_dim, int in_dim) {
#pragma omp parallel for schedule(dynamic, 256)
for (int i = 0; i < out_dim; i++) {
__m512 acc = _mm512_setzero_ps();
int j;
for (j = 0; j + 16 <= in_dim; j += 16) {
__m256i h = _mm256_loadu_si256((__m256i*)(w + (size_t)i * in_dim + j));
__m512 wv = _mm512_cvtph_ps(h);
__m512 xv = _mm512_loadu_ps(x + j);
acc = _mm512_fmadd_ps(wv, xv, acc);
}
float sum = _mm512_reduce_add_ps(acc);
for (; j < in_dim; j++) {
__m128i hv = _mm_set1_epi16(w[(size_t)i * in_dim + j]);
__m128 fv = _mm_cvtph_ps(hv);
float wf;
_mm_store_ss(&wf, fv);
sum += wf * x[j];
}
y[i] = sum;
}
}
/* ============================================================
* O(dim) operations - float is fine here, not the bottleneck
* ============================================================ */
/*
 * RMSNorm: y[k] = x[k] / sqrt(mean(x^2) + RMS_EPS) * w[k] for k in [0, dim).
 * Safe to call in place (y == x): the mean square is gathered in a first
 * pass, and each y[k] is written only after x[k] has been read.
 */
static void rmsnorm(const float *x, const float *w, float *y, int dim) {
    float sum_sq = 0.0f;
    int k = 0;
    while (k < dim) {
        sum_sq += x[k] * x[k];
        k++;
    }

    const float inv_rms = 1.0f / sqrtf(sum_sq / dim + RMS_EPS);
    for (k = 0; k < dim; k++)
        y[k] = x[k] * inv_rms * w[k];
}
/*
 * RMSNorm over one attention head's `len` values with a per-head gain vector.
 * Delegates to rmsnorm(); the separate name keeps attention() call sites
 * reading per-head. attention() calls it in place (head == dst).
 */
static void rmsnorm_head(const float *head, const float *gain, float *dst, int len) {
    rmsnorm(head, gain, dst, len);
}
static void silu_mul(const float *gate, const float *up, float *out, int n) {
for (int i = 0; i < n; i++)
out[i] = (gate[i] / (1.0f + expf(-gate[i]))) * up[i];
}
/* Residual accumulate in place: y[k] += x[k] for k in [0, n), front to back. */
static void vec_add(float *y, const float *x, int n) {
    int k = 0;
    while (k < n) {
        y[k] = y[k] + x[k];
        ++k;
    }
}
static void apply_rope(float *vec, int pos, int dim, float theta) {
for (int i = 0; i < dim; i += 2) {
float freq = 1.0f / powf(theta, (float)i / dim);
float angle = pos * freq;
float c = cosf(angle), s = sinf(angle);
float v0 = vec[i], v1 = vec[i + 1];
vec[i] = v0 * c - v1 * s;
vec[i + 1] = v0 * s + v1 * c;
}
}
static void softmax(float *x, int n) {
float mx = x[0];
for (int i = 1; i < n; i++) if (x[i] > mx) mx = x[i];
float sum = 0.0f;
for (int i = 0; i < n; i++) { x[i] = expf(x[i] - mx); sum += x[i]; }
float inv = 1.0f / sum;
for (int i = 0; i < n; i++) x[i] *= inv;
}
/* KV cache access */
/* Address of the head_dim-float slot for (layer, pos, kv_head) in a cache laid
 * out as [n_layers][MAX_SEQ][n_kv_heads][head_dim]. No bounds checking. */
static float* kv_ptr(float *cache, const Config *c, int layer, int pos, int kv_head) {
    const size_t per_pos = (size_t)c->n_kv_heads;
    const size_t per_layer = (size_t)MAX_SEQ * per_pos;
    const size_t slot = (size_t)layer * per_layer + (size_t)pos * per_pos + (size_t)kv_head;
    return cache + slot * (size_t)c->head_dim;
}
/* ============================================================
* ALLOC unary vector scratch
* ============================================================ */
/*
 * Initialise *uv for a dim-element vector with n_planes magnitude planes.
 * Allocates 64-byte-aligned sign and plane buffers, owned by uv and released
 * with free().
 *
 * Byte counts are rounded up to a non-zero multiple of 64. C11 requires
 * aligned_alloc's size to be a multiple of the alignment, and some
 * implementations reject other sizes. If an allocation fails, the
 * corresponding pointer is NULL; callers must check.
 */
static void uv_alloc(UnaryVec *uv, int dim, int n_planes) {
    int chunks = (dim + 63) / 64;
    uv->dim = dim;
    uv->chunks = chunks;
    uv->n_planes = n_planes;

    size_t sign_bytes = (size_t)chunks * sizeof(uint64_t);
    size_t plane_bytes = (size_t)n_planes * chunks * sizeof(uint64_t);
    sign_bytes = sign_bytes ? (sign_bytes + 63) & ~(size_t)63 : 64;
    plane_bytes = plane_bytes ? (plane_bytes + 63) & ~(size_t)63 : 64;

    uv->sign = (uint64_t *)aligned_alloc(64, sign_bytes);
    uv->planes = (uint64_t *)aligned_alloc(64, plane_bytes);
    uv->scale = 0.0f;
}
/* ============================================================
* ATTENTION (using pure unary for projections)
* ============================================================ */
/*
 * Self-attention for layer layer_idx at position pos.
 *
 * Input : m->normed (hidden floats, the RMSNorm'd residual stream).
 * Output: m->attn_out (hidden floats, the o_proj result).
 * Side effects: writes this position's K/V into the layer's cache slots.
 * Uses q_float / k_float / v_float / attn_scores as scratch.
 *
 * Requirements:
 * - 0 <= pos < MAX_SEQ (KV cache and attn_scores capacity).
 * - attn_out holds at least max(hidden, n_heads * head_dim) floats.
 *
 * Aborts with a message if the O-projection scratch cannot be allocated.
 */
static void attention(Model *m, int layer_idx, int pos) {
    Config *c = &m->cfg;
    Layer *layer = &m->layers[layer_idx];
    int heads_per_kv = c->n_heads / c->n_kv_heads;

    /* Quantize normed hidden to unary once; Q, K and V share this input. */
    quantize_to_unary(m->normed, c->hidden, c->a_planes,
                      m->uv_normed.sign, m->uv_normed.planes, &m->uv_normed.scale);

    /* Q, K, V projections - PURE UNARY */
    pure_unary_matvec(&layer->q_proj,
                      m->uv_normed.sign, m->uv_normed.planes, m->uv_normed.scale, c->a_planes,
                      m->q_float, m->acc_buf);
    pure_unary_matvec(&layer->k_proj,
                      m->uv_normed.sign, m->uv_normed.planes, m->uv_normed.scale, c->a_planes,
                      m->k_float, m->acc_buf);
    pure_unary_matvec(&layer->v_proj,
                      m->uv_normed.sign, m->uv_normed.planes, m->uv_normed.scale, c->a_planes,
                      m->v_float, m->acc_buf);

    /* QK-Norm (per head, one weight vector shared by every head) */
    if (layer->q_norm) {
        for (int h = 0; h < c->n_heads; h++)
            rmsnorm_head(m->q_float + h * c->head_dim, layer->q_norm,
                         m->q_float + h * c->head_dim, c->head_dim);
    }
    if (layer->k_norm) {
        for (int h = 0; h < c->n_kv_heads; h++)
            rmsnorm_head(m->k_float + h * c->head_dim, layer->k_norm,
                         m->k_float + h * c->head_dim, c->head_dim);
    }

    /* RoPE */
    for (int h = 0; h < c->n_heads; h++)
        apply_rope(m->q_float + h * c->head_dim, pos, c->head_dim, c->rope_theta);
    for (int h = 0; h < c->n_kv_heads; h++)
        apply_rope(m->k_float + h * c->head_dim, pos, c->head_dim, c->rope_theta);

    /* Store K, V to cache */
    for (int h = 0; h < c->n_kv_heads; h++) {
        memcpy(kv_ptr(m->k_cache, c, layer_idx, pos, h),
               m->k_float + h * c->head_dim, c->head_dim * sizeof(float));
        memcpy(kv_ptr(m->v_cache, c, layer_idx, pos, h),
               m->v_float + h * c->head_dim, c->head_dim * sizeof(float));
    }

    /* Attention scores + weighted sum (O(seq × head_dim), not O(dim²)) */
    float scale = 1.0f / sqrtf((float)c->head_dim);
    memset(m->attn_out, 0, c->n_heads * c->head_dim * sizeof(float));
    for (int h = 0; h < c->n_heads; h++) {
        /* Grouped-query attention: consecutive query heads share one KV head */
        int kv_h = h / heads_per_kv;
        float *q_head = m->q_float + h * c->head_dim;
        float *out_head = m->attn_out + h * c->head_dim;

        for (int t = 0; t <= pos; t++) {
            float *k_cached = kv_ptr(m->k_cache, c, layer_idx, t, kv_h);
            float dot = 0.0f;
            for (int d = 0; d < c->head_dim; d++)
                dot += q_head[d] * k_cached[d];
            m->attn_scores[t] = dot * scale;
        }
        softmax(m->attn_scores, pos + 1);

        for (int t = 0; t <= pos; t++) {
            float w = m->attn_scores[t];
            if (w < 1e-8f) continue;  /* negligible weight: skip the V row */
            float *v_cached = kv_ptr(m->v_cache, c, layer_idx, t, kv_h);
            for (int d = 0; d < c->head_dim; d++)
                out_head[d] += w * v_cached[d];
        }
    }

    /* O projection - quantize attn_out, then pure unary */
    int o_in = c->n_heads * c->head_dim;
    UnaryVec uv_attn;
    uv_alloc(&uv_attn, o_in, c->a_planes);
    if (!uv_attn.sign || !uv_attn.planes) {
        free(uv_attn.sign);
        free(uv_attn.planes);
        fprintf(stderr, "attention: failed to allocate O-projection scratch\n");
        abort();
    }
    quantize_to_unary(m->attn_out, o_in, c->a_planes,
                      uv_attn.sign, uv_attn.planes, &uv_attn.scale);

    /* attn_out's float contents are fully captured in uv_attn now, so the
     * projection can write its hidden-sized result straight back into it. */
    pure_unary_matvec(&layer->o_proj,
                      uv_attn.sign, uv_attn.planes, uv_attn.scale, c->a_planes,
                      m->attn_out, m->acc_buf);

    free(uv_attn.sign);
    free(uv_attn.planes);
}
/* ============================================================
* MLP (using pure unary for all projections)
* ============================================================ */
/*
 * SwiGLU MLP for layer layer_idx:
 *   normed -> unary -> gate_proj, up_proj -> silu(gate) * up -> unary -> down_proj
 *
 * Input : m->normed (hidden floats).
 * Output: m->normed, overwritten with the hidden-sized down_proj result.
 * Scratch: gate_float, up_float, mlp_act, uv_mlp_in, uv_mlp_act.
 */
static void mlp(Model *m, int layer_idx) {
    const Config *cfg = &m->cfg;
    const Layer *lyr = &m->layers[layer_idx];
    UnaryVec *in = &m->uv_mlp_in;
    UnaryVec *act = &m->uv_mlp_act;
    const int ap = cfg->a_planes;

    /* Gate and up share one quantized copy of the normed input */
    quantize_to_unary(m->normed, cfg->hidden, ap, in->sign, in->planes, &in->scale);
    pure_unary_matvec(&lyr->gate_proj, in->sign, in->planes, in->scale, ap,
                      m->gate_float, m->acc_buf);
    pure_unary_matvec(&lyr->up_proj, in->sign, in->planes, in->scale, ap,
                      m->up_float, m->acc_buf);

    /* O(inter) float nonlinearity */
    silu_mul(m->gate_float, m->up_float, m->mlp_act, cfg->inter);

    /* Re-quantize the activation for the down projection */
    quantize_to_unary(m->mlp_act, cfg->inter, ap, act->sign, act->planes, &act->scale);
    pure_unary_matvec(&lyr->down_proj, act->sign, act->planes, act->scale, ap,
                      m->normed, m->acc_buf);
}
/* ============================================================
* FORWARD ONE TOKEN
* ============================================================ */
/*
 * Run token_id at position pos through every layer. With tied embeddings,
 * it then computes logits with the FP16 embedding matrix.
 *
 * Returns m->logits (vocab floats). Returns NULL, without touching any model
 * state, if token_id is outside [0, vocab) or pos is outside [0, MAX_SEQ);
 * those values would index past the embedding table and KV cache.
 *
 * NOTE(review): this engine has no separate LM head. When tie_embeddings == 0,
 * m->logits is returned without being recomputed.
 */
float* forward_token(Model *m, int token_id, int pos) {
    Config *c = &m->cfg;
    if (token_id < 0 || token_id >= c->vocab || pos < 0 || pos >= MAX_SEQ)
        return NULL;

    embed_token(m->embed, token_id, m->hidden, c->hidden);

    for (int l = 0; l < c->n_layers; l++) {
        /* Pre-attention norm */
        rmsnorm(m->hidden, m->layers[l].input_norm, m->normed, c->hidden);
        /* Attention (quantizes normed internally, outputs to attn_out) */
        attention(m, l, pos);
        vec_add(m->hidden, m->attn_out, c->hidden);

        /* Post-attention norm */
        rmsnorm(m->hidden, m->layers[l].post_norm, m->normed, c->hidden);
        /* MLP (quantizes normed internally, outputs to normed) */
        mlp(m, l);
        vec_add(m->hidden, m->normed, c->hidden);
    }

    /* Final norm */
    rmsnorm(m->hidden, m->final_norm, m->normed, c->hidden);

    /* LM head - FP16 for now (vocab projection is O(vocab × hidden), not repeated per-layer) */
    if (c->tie_embeddings) {
        fp16_matvec(m->embed, m->normed, m->logits, c->vocab, c->hidden);
    }
    return m->logits;
}
/* ============================================================
* SAMPLING
* ============================================================ */
/*
 * Draw a token id from logits using temperature, nucleus (top-p) filtering
 * and a hard top-k cap of TOP_K candidates.
 *
 * logits is overwritten with softmax probabilities. temperature > 0 divides
 * the logits first. Candidates are taken in descending probability order
 * until their cumulative mass reaches top_p or TOP_K are kept. At least one
 * candidate (the argmax) is always kept, so top_p <= 0 degenerates to greedy.
 * A kept candidate is then drawn in proportion to its probability using rand().
 *
 * If the scratch buffers cannot be allocated, falls back to greedy argmax.
 * Returns 0 for vocab <= 0.
 */
static int sample_top_p(float *logits, int vocab, float temperature, float top_p) {
    enum { TOP_K = 40 };  /* cap on candidates, independent of top_p */

    if (vocab <= 0) return 0;

    if (temperature > 0) {
        float inv_t = 1.0f / temperature;
        for (int i = 0; i < vocab; i++) logits[i] *= inv_t;
    }
    softmax(logits, vocab);

    float *probs = (float *)malloc((size_t)vocab * sizeof(float));
    int *indices = (int *)malloc((size_t)vocab * sizeof(int));
    if (!probs || !indices) {
        free(probs);
        free(indices);
        int best = 0;
        for (int i = 1; i < vocab; i++)
            if (logits[i] > logits[best]) best = i;
        return best;
    }
    memcpy(probs, logits, (size_t)vocab * sizeof(float));
    for (int i = 0; i < vocab; i++) indices[i] = i;

    /* Partial selection sort: each step moves the largest remaining probability
     * to slot n_keep. The do-while guarantees at least one candidate. */
    int n_keep = 0;
    float cum = 0.0f;
    do {
        int best = n_keep;
        for (int i = n_keep + 1; i < vocab; i++)
            if (probs[i] > probs[best]) best = i;
        float tmp = probs[n_keep]; probs[n_keep] = probs[best]; probs[best] = tmp;
        int ti = indices[n_keep]; indices[n_keep] = indices[best]; indices[best] = ti;
        cum += probs[n_keep];
        n_keep++;
    } while (cum < top_p && n_keep < vocab && n_keep < TOP_K);

    float sum = 0.0f;
    for (int i = 0; i < n_keep; i++) sum += probs[i];
    float r = (float)rand() / RAND_MAX * sum;

    /* Default to the last candidate: float rounding can leave the running
     * sum marginally below r even after every kept probability is added. */
    float acc = 0.0f;
    int chosen = indices[n_keep - 1];
    for (int i = 0; i < n_keep; i++) {
        acc += probs[i];
        if (acc >= r) { chosen = indices[i]; break; }
    }

    free(probs);
    free(indices);
    return chosen;
}
/*
 * Prefill prompt_ids, then generate up to max_new_tokens tokens into out_tokens.
 *
 * temperature <= 0 selects greedy argmax; otherwise sample_top_p() is used.
 * Generation stops in three cases:
 * - after emitting eos_token (it is written to out_tokens);
 * - after max_new_tokens tokens;
 * - when the KV cache has no position left for the next token (MAX_SEQ).
 * Returns the number of tokens written to out_tokens. Seeds rand() from time().
 *
 * Returns 0 without generating if prompt_len is outside [1, MAX_SEQ] or if a
 * prompt token is rejected by forward_token(). An empty prompt would leave no
 * logits to sample from; a longer one would overflow the KV cache.
 */
int generate(
    Model *m,
    const int *prompt_ids, int prompt_len,
    int *out_tokens, int max_new_tokens,
    float temperature, float top_p, int eos_token
) {
    if (prompt_len <= 0 || prompt_len > MAX_SEQ)
        return 0;

    srand((unsigned)time(NULL));

    for (int i = 0; i < prompt_len; i++) {
        if (!forward_token(m, prompt_ids[i], i))
            return 0;
    }

    int pos = prompt_len;
    int generated = 0;
    for (int t = 0; t < max_new_tokens; t++) {
        int next;
        if (temperature <= 0) {
            next = 0;
            for (int i = 1; i < m->cfg.vocab; i++)
                if (m->logits[i] > m->logits[next]) next = i;
        } else {
            next = sample_top_p(m->logits, m->cfg.vocab, temperature, top_p);
        }
        out_tokens[t] = next;
        generated++;
        if (next == eos_token) break;

        /* Feeding `next` needs cache slot `pos`; stop once the cache is full. */
        if (pos >= MAX_SEQ) break;
        forward_token(m, next, pos);
        pos++;
    }
    return generated;
}
/* ============================================================
* ALLOCATION
* ============================================================ */
/* aligned_alloc(64, ...) with the byte count rounded up to a non-zero multiple
 * of 64, as C11 requires of aligned_alloc's size argument. NULL on failure. */
static void *alloc64(size_t bytes) {
    size_t rounded = bytes ? (bytes + 63) & ~(size_t)63 : 64;
    return aligned_alloc(64, rounded);
}

/* Free every buffer model_alloc() allocates, then m itself. Safe on a
 * partially built model because free(NULL) is a no-op. Borrowed weight
 * pointers (embed, norms, unary planes) are not touched. */
static void model_alloc_release(Model *m) {
    free(m->layers);
    free(m->k_cache);
    free(m->v_cache);
    free(m->hidden);
    free(m->normed);
    free(m->q_float);
    free(m->k_float);
    free(m->v_float);
    free(m->attn_out);
    free(m->gate_float);
    free(m->up_float);
    free(m->mlp_act);
    free(m->logits);
    free(m->attn_scores);
    free(m->final_norm);
    free(m->acc_buf);
    free(m->uv_normed.sign);
    free(m->uv_normed.planes);
    free(m->uv_mlp_in.sign);
    free(m->uv_mlp_in.planes);
    free(m->uv_mlp_act.sign);
    free(m->uv_mlp_act.planes);
    free(m);
}

/*
 * Allocate a Model for the given hyperparameters: layer table, zeroed KV cache
 * (n_layers × MAX_SEQ positions) and all per-token scratch. Prints a summary
 * to stdout. Weights are installed afterwards with the setter functions.
 *
 * attn_out is sized max(hidden, n_heads * head_dim). It holds the
 * concatenated head outputs and later the hidden-sized attention result
 * that forward_token() adds to the residual stream.
 *
 * Returns NULL (after freeing anything already allocated) on any
 * allocation failure.
 */
Model* model_alloc(
    int w_planes, int a_planes,
    int hidden, int inter, int n_heads, int n_kv_heads,
    int head_dim, int n_layers, int vocab,
    float rope_theta, int tie_embeddings
) {
    Model *m = (Model *)calloc(1, sizeof(Model));
    if (!m) return NULL;

    Config *c = &m->cfg;
    c->hidden = hidden; c->inter = inter;
    c->n_heads = n_heads; c->n_kv_heads = n_kv_heads;
    c->head_dim = head_dim; c->n_layers = n_layers;
    c->vocab = vocab; c->rope_theta = rope_theta;
    c->tie_embeddings = tie_embeddings;
    c->w_planes = w_planes; c->a_planes = a_planes;

    size_t hid = (size_t)hidden;
    size_t q_dim = (size_t)n_heads * head_dim;
    size_t kv_dim = (size_t)n_kv_heads * head_dim;
    size_t act_dim = (size_t)(inter > hidden ? inter : hidden);
    size_t attn_dim = q_dim > hid ? q_dim : hid;
    size_t acc_dim = (size_t)(inter > vocab ? inter : vocab);

    m->layers = (Layer *)calloc(n_layers, sizeof(Layer));
    size_t kv_size = (size_t)n_layers * MAX_SEQ * n_kv_heads * head_dim;
    m->k_cache = (float *)calloc(kv_size, sizeof(float));
    m->v_cache = (float *)calloc(kv_size, sizeof(float));
    m->hidden = (float *)alloc64(hid * sizeof(float));
    m->normed = (float *)alloc64(act_dim * sizeof(float));
    m->q_float = (float *)alloc64(q_dim * sizeof(float));
    m->k_float = (float *)alloc64(kv_dim * sizeof(float));
    m->v_float = (float *)alloc64(kv_dim * sizeof(float));
    m->attn_out = (float *)alloc64(attn_dim * sizeof(float));
    m->gate_float = (float *)alloc64((size_t)inter * sizeof(float));
    m->up_float = (float *)alloc64((size_t)inter * sizeof(float));
    m->mlp_act = (float *)alloc64((size_t)inter * sizeof(float));
    m->logits = (float *)alloc64((size_t)vocab * sizeof(float));
    m->attn_scores = (float *)alloc64((size_t)MAX_SEQ * sizeof(float));
    m->final_norm = (float *)alloc64(hid * sizeof(float));
    m->acc_buf = (int *)alloc64(acc_dim * sizeof(int));

    /* Unary vector scratch */
    uv_alloc(&m->uv_normed, hidden, a_planes);
    uv_alloc(&m->uv_mlp_in, hidden, a_planes);
    uv_alloc(&m->uv_mlp_act, inter, a_planes);

    if (!m->layers || !m->k_cache || !m->v_cache || !m->hidden || !m->normed ||
        !m->q_float || !m->k_float || !m->v_float || !m->attn_out ||
        !m->gate_float || !m->up_float || !m->mlp_act || !m->logits ||
        !m->attn_scores || !m->final_norm || !m->acc_buf ||
        !m->uv_normed.sign || !m->uv_normed.planes ||
        !m->uv_mlp_in.sign || !m->uv_mlp_in.planes ||
        !m->uv_mlp_act.sign || !m->uv_mlp_act.planes) {
        fprintf(stderr, "model_alloc: out of memory\n");
        model_alloc_release(m);
        return NULL;
    }

    size_t kv_mb = kv_size * 2 * sizeof(float) / (1024*1024);
    printf("PURE UNARY ENGINE\n");
    printf(" Model: hidden=%d inter=%d heads=%d/%d layers=%d vocab=%d\n",
           hidden, inter, n_heads, n_kv_heads, n_layers, vocab);
    printf(" Weight planes: %d, Activation planes: %d\n", w_planes, a_planes);
    printf(" Plane pairs per matvec element: %d\n", w_planes * a_planes);
    printf(" KV cache: %zu MB\n", kv_mb);
    printf(" Float ops: RMSNorm, SiLU, Softmax, RoPE, residual (all O(dim))\n");
    printf(" Integer ops: ALL matmuls (O(dim²) — the actual bottleneck)\n");
    return m;
}
/* Weight setters (same interface as v2) */
/* Install the FP16 embedding table. The pointer is borrowed, not copied,
 * so the data must outlive m. */
void model_set_embed(Model *m, uint16_t *data) {
    m->embed = data;
}
/* Copy `hidden` floats of final RMSNorm weights into the model's own buffer. */
void model_set_final_norm(Model *m, float *data) {
    const size_t bytes = (size_t)m->cfg.hidden * sizeof(float);
    memcpy(m->final_norm, data, bytes);
}
/* Point layer l's pre-attention and pre-MLP RMSNorm weights at caller-owned
 * arrays (borrowed, not copied). */
void layer_set_norms(Model *m, int l, float *in_norm, float *post_norm) {
    Layer *layer = m->layers + l;
    layer->input_norm = in_norm;
    layer->post_norm = post_norm;
}
/* Point layer l's per-head Q/K RMSNorm weights at caller-owned arrays
 * (borrowed). Passing NULL makes attention() skip that norm. */
void layer_set_qk_norm(Model *m, int l, float *q_norm, float *k_norm) {
    Layer *layer = m->layers + l;
    layer->q_norm = q_norm;
    layer->k_norm = k_norm;
}
/* Describe an out_dim x in_dim unary weight backed by caller-owned buffers
 * (borrowed, not copied). chunks is derived as ceil(in_dim / 64). */
static void init_unary_weight(
    UnaryWeight *uw,
    uint64_t *sign, uint64_t *planes, float *scales,
    int out_dim, int in_dim, int n_planes
) {
    *uw = (UnaryWeight){
        .sign_bits = sign,
        .mag_planes = planes,
        .scales = scales,
        .out_dim = out_dim,
        .in_dim = in_dim,
        .n_planes = n_planes,
        .chunks = (in_dim + 63) / 64,
    };
}
/* Install all seven unary projections of layer l. Each triple is
 * (sign bits, magnitude planes, per-row scales) followed by (out_dim, in_dim).
 * Every projection uses the same n_planes. Buffers are borrowed, not copied. */
void layer_set_linears(
    Model *m, int l,
    uint64_t *q_s, uint64_t *q_p, float *q_sc, int q_out, int q_in,
    uint64_t *k_s, uint64_t *k_p, float *k_sc, int k_out, int k_in,
    uint64_t *v_s, uint64_t *v_p, float *v_sc, int v_out, int v_in,
    uint64_t *o_s, uint64_t *o_p, float *o_sc, int o_out, int o_in,
    uint64_t *g_s, uint64_t *g_p, float *g_sc, int g_out, int g_in,
    uint64_t *u_s, uint64_t *u_p, float *u_sc, int u_out, int u_in,
    uint64_t *d_s, uint64_t *d_p, float *d_sc, int d_out, int d_in,
    int n_planes
) {
    Layer *layer = m->layers + l;

    /* Attention projections */
    init_unary_weight(&layer->q_proj, q_s, q_p, q_sc, q_out, q_in, n_planes);
    init_unary_weight(&layer->k_proj, k_s, k_p, k_sc, k_out, k_in, n_planes);
    init_unary_weight(&layer->v_proj, v_s, v_p, v_sc, v_out, v_in, n_planes);
    init_unary_weight(&layer->o_proj, o_s, o_p, o_sc, o_out, o_in, n_planes);

    /* MLP projections */
    init_unary_weight(&layer->gate_proj, g_s, g_p, g_sc, g_out, g_in, n_planes);
    init_unary_weight(&layer->up_proj, u_s, u_p, u_sc, u_out, u_in, n_planes);
    init_unary_weight(&layer->down_proj, d_s, d_p, d_sc, d_out, d_in, n_planes);
}
/* Zero both KV caches across all layers and all MAX_SEQ positions.
 * No other model state is touched. */
void model_reset_cache(Model *m) {
    const Config *c = &m->cfg;
    const size_t bytes =
        (size_t)c->n_layers * MAX_SEQ * c->n_kv_heads * c->head_dim * sizeof(float);
    memset(m->k_cache, 0, bytes);
    memset(m->v_cache, 0, bytes);
}