#pragma once

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "config.h"

extern char _embedded_binary_model[];

typedef float fp32_t;

typedef struct MambaWeights MambaWeights;
typedef struct MambaConfig MambaConfig;
typedef struct MambaState MambaState;
typedef struct Mamba Mamba;

static void MambaLog(Mamba *);
static fp32_t *MambaForwardLayer(Mamba *, uint64_t);
static fp32_t *MambaForward(Mamba *, uint64_t);

#define Tensor(NAME, X, Y, Z) fp32_t NAME[(X) * (Y) * (Z)]

struct MambaWeights {
    Tensor(embed,          ROUNDED_VOCAB_SIZE, D_MODEL, 1);
    Tensor(in_proj,        N_LAYER, 2 * D_INNER, D_MODEL);
    Tensor(conv1d_weight,  N_LAYER, D_INNER, D_CONV);
    Tensor(conv1d_bias,    N_LAYER, D_INNER, 1);
    Tensor(x_proj,         N_LAYER, DT_RANK + 2 * D_STATE, D_INNER);
    Tensor(dt_proj_weight, N_LAYER, D_INNER, DT_RANK);
    Tensor(dt_proj_bias,   N_LAYER, D_INNER, 1);
    Tensor(A,              N_LAYER, D_INNER, D_STATE);
    Tensor(D,              N_LAYER, D_INNER, 1);
    Tensor(out_proj,       N_LAYER, D_MODEL, D_INNER);
    Tensor(norm,           N_LAYER, D_MODEL, 1);
    Tensor(norm_f,         D_MODEL, 1, 1);
    Tensor(lm_head,        ROUNDED_VOCAB_SIZE, D_MODEL, 1);
};

struct MambaConfig {
    uint64_t vocab_size; // vocabulary size, rounded to the nearest multiple of 8
    uint64_t n_layer;    // number of layers
    uint64_t d_model;    // embedding dim
    uint64_t d_inner;    // expanded inner dim
    uint64_t dt_rank;    // rank of the dt (delta) projection
    uint64_t d_state;    // SSM state dim per inner channel
    uint64_t d_conv;     // depthwise conv kernel width
};

struct MambaState {
    Tensor(hidden_state, D_MODEL, 1, 1);
    Tensor(conv_state,   N_LAYER, D_INNER, D_CONV);
    Tensor(ssm_state,    N_LAYER, D_INNER, D_STATE);
};

struct Mamba {
    MambaConfig config;
    MambaState state;
    MambaWeights *weights;
    void    (*log)           (Mamba *);
    fp32_t *(*forward_layer) (Mamba *, uint64_t);
    fp32_t *(*forward)       (Mamba *, uint64_t);
};

Mamba mamba = {
    .config = {
        .vocab_size = ROUNDED_VOCAB_SIZE,
        .n_layer    = N_LAYER,
        .d_model    = D_MODEL,
        .d_inner    = D_INNER,
        .dt_rank    = DT_RANK,
        .d_state    = D_STATE,
        .d_conv     = D_CONV,
    },
    .state = {0},
    .weights = (MambaWeights *) _embedded_binary_model,
    .log = MambaLog,
    .forward_layer = MambaForwardLayer,
    .forward = MambaForward
};

static void rmsnorm(fp32_t *y, fp32_t *x, fp32_t *weight, uint64_t size) {
    fp32_t ss = 0.0f;
    for (uint64_t j = 0; j < size; ++j) ss += x[j] * x[j];
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    for (uint64_t j = 0; j < size; ++j) y[j] = x[j] * weight[j] * ss;
}

static fp32_t softplus(fp32_t x) { return logf(1.0f + expf(x)); }

static fp32_t sigmoid(fp32_t x) { return 1.0f / (1.0f + expf(-x)); }

static fp32_t silu(fp32_t x) { return x * sigmoid(x); }

// Shift every row of `matrix` left by one column and write x[i] into the last column of row i.
static void shift_and_update_last_column(fp32_t *matrix, fp32_t *x, uint64_t rows, uint64_t cols) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < rows; ++i) {
        fp32_t *row = matrix + i * cols;
        for (uint64_t j = 0; j < cols - 1; ++j) row[j] = row[j + 1];
        row[cols - 1] = x[i];
    }
}

// Depthwise causal 1D convolution over the per-channel history, followed by SiLU.
static void conv1d_silu(fp32_t *x, fp32_t *conv_state, fp32_t *conv1d_weight, fp32_t *conv1d_bias,
                        uint64_t d_inner, uint64_t d_conv) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d_inner; ++i) {
        fp32_t val = 0.0f;
        for (uint64_t j = 0; j < d_conv; ++j) {
            uint64_t index = i * d_conv + j;
            val += conv_state[index] * conv1d_weight[index];
        }
        x[i] = silu(val + conv1d_bias[i]);
    }
}

// y = softplus(W x + b), used for the dt (delta) projection.
static void dense_softplus(fp32_t *y, fp32_t *x, fp32_t *w, fp32_t *b, uint64_t d, uint64_t n) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d; ++i) {
        fp32_t val = 0.0f;
        for (uint64_t j = 0; j < n; ++j) val += w[i * n + j] * x[j];
        y[i] = softplus(val + b[i]);
    }
}

// One step of the selective state-space recurrence, gated by silu(z).
static void selective_scan(fp32_t *y, fp32_t *ssm_state, fp32_t *dt, fp32_t *A, fp32_t *B, fp32_t *C,
                           fp32_t *D, fp32_t *x, fp32_t *z, uint64_t d_inner, uint64_t d_state) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d_inner; ++i) {
        fp32_t val = 0.0f;
        for (uint64_t j = 0; j < d_state; ++j) {
            uint64_t index = i * d_state + j;
            fp32_t dA = expf(dt[i] * A[index]);
            fp32_t dB = dt[i] * B[j];
            ssm_state[index] = ssm_state[index] * dA + x[i] * dB;
            val += ssm_state[index] * C[j];
        }
        val += D[i] * x[i];
        y[i] = val * silu(z[i]);
    }
}
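/*
 * For reference, the per-token update that selective_scan above computes for each
 * inner channel i and state index j (this simply restates the loop in equation form,
 * it is not an additional kernel):
 *
 *     h[i][j] <- exp(dt[i] * A[i][j]) * h[i][j] + dt[i] * B[j] * x[i]
 *     y[i]     = sum_j C[j] * h[i][j] + D[i] * x[i]
 *     y[i]    *= silu(z[i])
 */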
static void matmul(fp32_t *y, fp32_t *x, fp32_t *w, uint64_t d, uint64_t n) {
    #pragma omp parallel for
    for (uint64_t i = 0; i < d; ++i) {
        fp32_t val = 0.0f;
        for (uint64_t j = 0; j < n; ++j) val += w[i * n + j] * x[j];
        y[i] = val;
    }
}

#define LOG(...) fprintf(stderr, __VA_ARGS__)
#define CLOG(X, ...) if (X) LOG(__VA_ARGS__)

// Format a count with a short-scale suffix (k/m/b), e.g. 1300000 -> "1m".
static char *shortScale(char *result, int64_t number) {
    char *suffixes[] = {"", "k", "m", "b"};
    uint64_t magnitude = 0;
    fp32_t num = (fp32_t)number;
    if (number < 1000) {
        sprintf(result, "%ld", number);
        return result;
    }
    while (number >= 1000 || number <= -1000) {
        magnitude++;
        number /= 1000;
        num /= 1000.0;
    }
    sprintf(result, "%.0f%s", num, suffixes[magnitude]);
    return result;
}

static inline void nparams(uint64_t dim) {
    char buff[16];
    shortScale(buff, dim);
    LOG("%12lu (%s)", dim, buff);
}

#define NPARAMS(X) nparams(sizeof(X) / sizeof(fp32_t))

static void MambaLog(Mamba *mamba) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    LOG("Mamba Config:");
    LOG("\n\tvocab_size: %12lu", p->vocab_size);
    LOG("\n\tn_layer: %12lu", p->n_layer);
    LOG("\n\td_model: %12lu", p->d_model);
    LOG("\n\td_inner: %12lu", p->d_inner);
    LOG("\n\tdt_rank: %12lu", p->dt_rank);
    LOG("\n\td_state: %12lu", p->d_state);
    LOG("\n\td_conv: %12lu", p->d_conv);
    LOG("\n\n\n");

    LOG("Parameters Count:");
    LOG("\n\tembed: ");          NPARAMS(w->embed);
    LOG("\n\tin_proj: ");        NPARAMS(w->in_proj);
    LOG("\n\tconv1d_weight: ");  NPARAMS(w->conv1d_weight);
    LOG("\n\tconv1d_bias: ");    NPARAMS(w->conv1d_bias);
    LOG("\n\tx_proj: ");         NPARAMS(w->x_proj);
    LOG("\n\tdt_proj_weight: "); NPARAMS(w->dt_proj_weight);
    LOG("\n\tdt_proj_bias: ");   NPARAMS(w->dt_proj_bias);
    LOG("\n\tA: ");              NPARAMS(w->A);
    LOG("\n\tD: ");              NPARAMS(w->D);
    LOG("\n\tout_proj: ");       NPARAMS(w->out_proj);
    LOG("\n\tnorm: ");           NPARAMS(w->norm);
    LOG("\n\tnorm_f: ");         NPARAMS(w->norm_f);
    LOG("\n\tlm_head: ");        NPARAMS(w->lm_head);
    LOG("\n\n\tTotal: ");        NPARAMS(MambaWeights);
    LOG("\n\n\n");

    LOG("Recurrent State:");
    LOG("\n\thidden_state: "); NPARAMS(s->hidden_state);
    LOG("\n\tconv_state: ");   NPARAMS(s->conv_state);
    LOG("\n\tssm_state: ");    NPARAMS(s->ssm_state);
    LOG("\n\n\tTotal: ");      NPARAMS(MambaState);
    LOG("\n\n\n");
}

// One Mamba block: input projection, depthwise conv, selective SSM, output projection.
// Reads and overwrites s->hidden_state; updates this layer's conv and ssm state in place.
static fp32_t *MambaForwardLayer(Mamba *mamba, uint64_t layer) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    uint64_t d_model = p->d_model, d_inner = p->d_inner, d_conv = p->d_conv,
             d_state = p->d_state, dt_rank = p->dt_rank;
    fp32_t *hidden_state = s->hidden_state,
           *conv_state = s->conv_state + layer * d_inner * d_conv,
           *ssm_state = s->ssm_state + layer * d_inner * d_state;

    Tensor(xz,   2 * D_INNER, 1, 1);
    Tensor(x_db, DT_RANK + 2 * D_STATE, 1, 1);
    Tensor(dt,   D_INNER, 1, 1);
    Tensor(y,    D_INNER, 1, 1);
    fp32_t *x = xz, *z = xz + d_inner, *B = x_db + dt_rank, *C = x_db + dt_rank + d_state;

    // Proj input
    matmul(xz, hidden_state, w->in_proj + layer * 2 * d_inner * d_model, 2 * d_inner, d_model);

    // Conv
    shift_and_update_last_column(conv_state, x, d_inner, d_conv);
    conv1d_silu(x, conv_state, w->conv1d_weight + layer * d_inner * d_conv,
                w->conv1d_bias + layer * d_inner, d_inner, d_conv);

    // SSM
    matmul(x_db, x, w->x_proj + layer * (dt_rank + 2 * d_state) * d_inner, dt_rank + 2 * d_state, d_inner);
    dense_softplus(dt, x_db, w->dt_proj_weight + layer * d_inner * dt_rank,
                   w->dt_proj_bias + layer * d_inner, d_inner, dt_rank);
    selective_scan(y, ssm_state, dt, w->A + layer * d_inner * d_state, B, C,
                   w->D + layer * d_inner, x, z, d_inner, d_state);

    // Proj output
    matmul(hidden_state, y, w->out_proj + layer * d_model * d_inner, d_model, d_inner);

    return hidden_state;
}
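/*
 * Single-token forward pass (restating the control flow of MambaForward below):
 * look up the token embedding, then for each layer apply pre-RMSNorm, the Mamba
 * block, and a residual connection back into the running input; finally apply the
 * output RMSNorm and the lm_head projection to produce logits over the (rounded)
 * vocabulary.
 */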
static fp32_t *MambaForward(Mamba *mamba, uint64_t token) {
    MambaConfig *p = &mamba->config;
    MambaWeights *w = mamba->weights;
    MambaState *s = &mamba->state;

    uint64_t d_model = p->d_model, n_layer = p->n_layer, vocab_size = p->vocab_size;

    Tensor(input, D_MODEL, 1, 1);
    static Tensor(logits, ROUNDED_VOCAB_SIZE, 1, 1);

    fp32_t *hidden_state = s->hidden_state,
           *content_row = w->embed + token * d_model;
    memcpy(input, content_row, d_model * sizeof(fp32_t));

    for (uint64_t layer = 0; layer < n_layer; ++layer) {
        // Pre-norm, Mamba block, then residual connection back into the running input.
        rmsnorm(hidden_state, input, w->norm + layer * d_model, d_model);
        mamba->forward_layer(mamba, layer);
        for (uint64_t i = 0; i < d_model; ++i) {
            hidden_state[i] += input[i];
            input[i] = hidden_state[i];
        }
    }

    rmsnorm(hidden_state, hidden_state, w->norm_f, d_model);
    matmul(logits, hidden_state, w->lm_head, vocab_size, d_model);

    return logits;
}
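/*
 * Usage sketch (not part of the model itself): a minimal greedy decoding loop built
 * on the public entry point `mamba.forward`. `first_token` and `n_steps` are
 * hypothetical caller-side values; tokenization and sampling strategy are out of
 * scope for this header.
 *
 *     uint64_t token = first_token;
 *     for (uint64_t step = 0; step < n_steps; ++step) {
 *         fp32_t *logits = mamba.forward(&mamba, token);      // one recurrent step
 *         uint64_t best = 0;                                   // greedy argmax over the vocab
 *         for (uint64_t i = 1; i < mamba.config.vocab_size; ++i)
 *             if (logits[i] > logits[best]) best = i;
 *         token = best;                                        // feed the prediction back in
 *     }
 */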