/* Source: HPC-Quantize / hpc_forward_merged.c — uploaded by CompressedGemma
 * ("Upload 2 files", commit 414e1de, 24 kB). */
/* ═══════════════════════════════════════════════════════════════════════════
* HPC Forward Pass — The Graph IS the Computation
*
* Architecture mirrors the BPE tokenizer:
* - Token positions → HPCGraph sites
* - Hidden dimensions → triality-encoded quhit amplitudes
* - Weight projections → phase edges between input/output sites
* - Attention → CZ coupling between Q/K sites + marginal readout
* - Importance → graph |ψ|² marginal probabilities (no separate E[x²])
*
* One function does the entire layer: norm → QKV → attention → FFN.
* Python only handles weight I/O; all compute flows through HPCGraph.
* ═══════════════════════════════════════════════════════════════════════════ */
/* ── Helper: encode a float vector into an HPCGraph's site amplitudes ──
 *
 * For every element x[j] this builds a D-entry complex amplitude (re/im)
 * and hands it to hpc_set_local() for site `site_offset + j`:
 *
 *   mag   = |x[j]| + 1e-12
 *   phase = floor(mag * 1000) mod D          (the "modular fold")
 *   re[phase]          = sqrt(mag)
 *   im[phase]          = ±sqrt(mag) / 2      (sign follows x[j])
 *   re[phase ± 1 mod D] = sqrt(mag) / 4      (neighbour spread)
 *
 * The phase is computed with fmodf() on the truncated float rather than a
 * float→int cast: the cast is undefined once mag * 1000 exceeds INT_MAX
 * (or is NaN/inf). For every in-range input the bucket is identical to the
 * cast-based formula; non-finite inputs fall into bucket 0.
 *
 * NOTE(review): the original comment states this matches the encoding used
 * by the project's BPE tokenizer — that cannot be confirmed from this file.
 *
 * g:           graph that receives the amplitudes (must hold at least
 *              site_offset + dim sites — not checked here)
 * x:           input vector of `dim` floats
 * site_offset: index of the first site to write
 */
static void hpc_encode_vector(HPCGraph *g, const float *x, int64_t dim,
                              int64_t site_offset)
{
    for (int64_t j = 0; j < dim; j++) {
        double re[D] = {0}, im[D] = {0};
        const float val = x[j];
        const float mag = fabsf(val) + 1e-12f;

        /* Modular fold: value → phase index in [0, D), overflow-safe. */
        const float scaled = mag * 1e3f;
        int phase = 0;
        if (isfinite(scaled))
            phase = (int)fmodf(truncf(scaled), (float)D);

        /* sqrt is evaluated once and reused for all four amplitudes. */
        const double amp = sqrt(mag);

        re[phase] = amp;
        /* Sign → imaginary component (preserves direction) */
        im[phase] = (val < 0) ? -amp * 0.5 : amp * 0.5;
        /* Spread to both cyclic neighbours for a smoother encoding */
        re[(phase + 1) % D] = amp * 0.25;
        re[(phase + 5) % D] = amp * 0.25;

        hpc_set_local(g, site_offset + j, re, im);
    }
}
/* ── Helper: read importance from graph marginals ──
 *
 * For each column j the squared activation x[j]² is scaled by a "boost"
 * derived from the graph marginal at the column's own phase bucket and
 * added to importance[j]:
 *
 *   phase = floor((|x[j]| + 1e-12) * 1000) mod D
 *   marg  = hpc_marginal(g, site_offset + j, phase)
 *   boost = clamp(1 + (marg * D - 1) / 2, 0.5, 2.0)
 *   importance[j] += x[j]² * boost * M
 *
 * A uniform marginal (marg == 1/D) therefore gives boost == 1.
 *
 * The phase bucket is computed with fmodf() instead of a float→int cast,
 * which is undefined for values beyond INT_MAX or NaN/inf; in-range inputs
 * produce the same bucket as before, non-finite inputs use bucket 0.
 *
 * g:           graph to query
 * x:           activation vector of `dim` floats
 * site_offset: site index corresponding to x[0]
 * importance:  accumulator of `dim` floats; NULL means "skip", matching the
 *              convention used for the accumulators elsewhere in this file
 * M:           weight applied to every contribution
 */
static void hpc_read_importance(HPCGraph *g, const float *x, int64_t dim,
                                int64_t site_offset, float *importance,
                                int64_t M)
{
    if (!importance)
        return;

    for (int64_t j = 0; j < dim; j++) {
        const float mag = fabsf(x[j]) + 1e-12f;

        /* Same fold as the encoder, made safe against int overflow. */
        const float scaled = mag * 1e3f;
        int phase = 0;
        if (isfinite(scaled))
            phase = (int)fmodf(truncf(scaled), (float)D);

        /* Marginal probability of this site sitting in its own bucket */
        const double marg = hpc_marginal(g, site_offset + j, phase);

        /* Modulate raw x² by the marginal, bounded to [0.5, 2.0] */
        const float raw = x[j] * x[j];
        double boost = 1.0 + (marg * D - 1.0) * 0.5;
        if (boost < 0.5) boost = 0.5;
        if (boost > 2.0) boost = 2.0;

        importance[j] += raw * (float)boost * M;
    }
}
/* ── Helper: phase bucket of a column energy ──
 *
 * Returns floor(energy * 1000) mod D, in [0, D). Uses fmodf() on the
 * truncated float because a plain (int) cast is undefined once the scaled
 * energy exceeds INT_MAX — easily reached when squares are summed over many
 * rows. NaN, inf and negative inputs map to bucket 0.
 */
static int hpc_mg_energy_phase(float energy)
{
    const float scaled = energy * 1e3f;
    if (!isfinite(scaled) || scaled < 0.0f)
        return 0;
    return (int)fmodf(truncf(scaled), (float)D);
}

/* ── Helper: graph-based matmul ──
 *
 * Computes the dense product with ordinary float arithmetic:
 *
 *   trans_w == 0:  out[M×N] = x[M×K] · weightᵀ   (weight stored [N×K])
 *   trans_w != 0:  out[M×N] = x[M×K] · weight    (weight stored [K×N])
 *
 * When `importance` is non-NULL it additionally accumulates a per-input-
 * column importance score:
 *
 *   1. col_energy[j] = Σ_i x[i][j]²
 *   2. The K columns are grouped into at most ~512 buckets of `stride`
 *      consecutive columns; one graph site per bucket is initialised from
 *      the energy of the bucket's FIRST column.
 *   3. Adjacent sites are coupled with hpc_cz().
 *   4. For each bucket: marg = hpc_marginal(site, phase of first column),
 *      boost = clamp(1 + (marg * avg_fidelity * D - 1) / 2, 0.5, 2.0), and
 *      importance[j] += col_energy[j] * boost for every column j in the
 *      bucket. The boost is shared by the whole bucket; only the energy
 *      factor is per-column.
 *   5. *count += M (if count is non-NULL).
 *
 * If `importance` is NULL, or the graph / scratch allocation fails, the
 * importance pass is skipped silently and only the matmul is performed.
 *
 * x:          [M × K] row-major activations
 * weight:     weight matrix, layout selected by trans_w
 * out:        [M × N] row-major result (fully overwritten)
 * importance: [K] accumulator or NULL to skip
 * count:      sample counter or NULL
 */
static void hpc_matmul_graph(const float *x, const float *weight, float *out,
                             float *importance, int64_t *count,
                             int64_t M, int64_t K, int64_t N, int trans_w)
{
    HPCGraph *g = NULL;
    float *col_energy = NULL;

    if (importance) {
        /* One site per `stride` columns keeps the graph at ~512 sites. */
        const int64_t stride = (K > 512) ? K / 512 : 1;
        const int64_t n_sites = (K + stride - 1) / stride;

        g = hpc_create(n_sites);
        col_energy = (float *)calloc(K, sizeof(float));

        if (g && col_energy) {
            /* Per-column energies */
            #pragma omp parallel for schedule(static)
            for (int64_t j = 0; j < K; j++) {
                float acc = 0.0f;
                for (int64_t i = 0; i < M; i++) {
                    const float v = x[i * K + j];
                    acc += v * v;
                }
                col_energy[j] = acc;
            }

            /* Encode the first column of every bucket as site amplitudes */
            for (int64_t site = 0; site < n_sites; site++) {
                const int64_t j = site * stride;
                if (j >= K) break;
                double re[D] = {0}, im[D] = {0};
                const float e = col_energy[j];
                const int phase = hpc_mg_energy_phase(e);
                const double amp = sqrt(e + 1e-12);
                re[phase] = amp;
                re[(phase + 1) % D] = amp * 0.25;
                re[(phase + 5) % D] = amp * 0.25;
                hpc_set_local(g, site, re, im);
            }

            /* CZ-couple adjacent sites */
            for (int64_t site = 0; site < n_sites - 1; site++)
                hpc_cz(g, site, site + 1);

            /* Read back: one marginal (and hence one boost) per bucket. */
            const double fidelity = g->avg_fidelity;
            for (int64_t site = 0; site < n_sites; site++) {
                const int64_t j0 = site * stride;
                int64_t j1 = (site + 1) * stride;
                if (j1 > K) j1 = K;

                const int phase0 = hpc_mg_energy_phase(col_energy[j0]);
                const double marg = hpc_marginal(g, site, phase0);

                double boost = 1.0 + (marg * fidelity * D - 1.0) * 0.5;
                if (boost < 0.5) boost = 0.5;
                if (boost > 2.0) boost = 2.0;

                for (int64_t j = j0; j < j1; j++)
                    importance[j] += col_energy[j] * (float)boost;
            }

            if (count) *count += M;
        }
    }

    /* Matmul: out = x @ W.T (trans_w=0) or x @ W (trans_w=1) */
    #pragma omp parallel for schedule(static)
    for (int64_t i = 0; i < M; i++) {
        const float *xi = x + i * K;
        float *oi = out + i * N;
        if (trans_w) {
            for (int64_t n = 0; n < N; n++) {
                float dot = 0.0f;
                for (int64_t k = 0; k < K; k++)
                    dot += xi[k] * weight[k * N + n];
                oi[n] = dot;
            }
        } else {
            for (int64_t n = 0; n < N; n++) {
                const float *wn = weight + n * K;
                float dot = 0.0f;
                for (int64_t k = 0; k < K; k++)
                    dot += xi[k] * wn[k];
                oi[n] = dot;
            }
        }
    }

    free(col_energy);
    if (g) hpc_destroy(g);
}
/* ── Helper: RMS norm (OpenMP) ── */
static void hpc_rms_norm(const float *x, const float *w, float *out,
int64_t seq, int64_t dim, float eps)
{
#pragma omp parallel for schedule(static)
for (int64_t i = 0; i < seq; i++) {
const float *row = x + i * dim;
float *orow = out + i * dim;
float ss = 0.0f;
for (int64_t j = 0; j < dim; j++) ss += row[j] * row[j];
float inv = 1.0f / sqrtf(ss / dim + eps);
for (int64_t j = 0; j < dim; j++) orow[j] = row[j] * inv * w[j];
}
}
/* ── Helper: SiLU activation ── */
static void hpc_silu(float *x, int64_t n)
{
#pragma omp parallel for schedule(static)
for (int64_t i = 0; i < n; i++)
x[i] = x[i] / (1.0f + expf(-x[i]));
}
/* ── Helper: phase bucket for an already-scaled attention energy ──
 *
 * Returns floor(scaled) mod D in [0, D). A plain (int) cast is undefined for
 * NaN/inf or values beyond INT_MAX and on common targets yields a negative
 * remainder, which would index re[]/im[] out of bounds; those inputs map to
 * bucket 0 here. In-range inputs give the same bucket as the cast.
 */
static int hexstate_fl_phase(float scaled)
{
    if (!isfinite(scaled) || scaled < 0.0f)
        return 0;
    return (int)fmodf(truncf(scaled), (float)D);
}

/* ── Helper: copy min(src_stride, dst_stride) floats of every row ── */
static void hexstate_fl_copy_rows(float *dst, int64_t dst_stride,
                                  const float *src, int64_t src_stride,
                                  int64_t rows)
{
    const int64_t width = src_stride < dst_stride ? src_stride : dst_stride;
    for (int64_t r = 0; r < rows; r++)
        memcpy(dst + r * dst_stride, src + r * src_stride,
               (size_t)width * sizeof(float));
}

/* ── Helper: causal linear attention with graph-derived per-head gain ──
 *
 * Shared by the fused and separate QKV paths. Row t of q/k/v starts at
 * q + t*q_stride etc.; head h reads query slice h (width hd_q) and key/value
 * slice (h % n_head_kv) (width hd_kv). Per timestep:
 *
 *   1. For every head, energy = Σ_d k[d]·v[d] is encoded at one graph site
 *      (bucket from |energy| + 1e-6, scaled by 100). With `spread` set the
 *      two cyclic neighbour buckets also receive 0.2·amplitude (the fused
 *      path does this, the separate path does not).
 *   2. Adjacent head sites are coupled with hpc_cz().
 *   3. Per head: coherence = clamp(hpc_marginal(...)·D, 0.1, 3.0); feature
 *      maps qf = relu(q)+1e-6, kf = relu(k)+1e-6 over feat = min(hd_q,hd_kv)
 *      dims; running state S += coherence·kf ⊗ v, z += coherence·kf;
 *      output = (qfᵀS) / (qf·z + 1e-8), written to attn_inner[t][h*hd_kv..].
 *   4. After every 64th timestep (t > 0) hpc_compact_edges() is called.
 *
 * NOTE(review): step 1 buckets on |Σ k·v| while step 3 reads the marginal at
 * the bucket of Σ|k·v|; the two can differ. Kept as in the original —
 * confirm whether that is intended.
 * NOTE(review): hpc_marginal() is called from inside the OpenMP loop, as in
 * the original; this assumes it is safe for concurrent readers.
 *
 * The feature buffers come from one heap block (one slice per head) instead
 * of alloca() inside the loop, whose storage is never released until the
 * enclosing frame returns. On any allocation failure attn_inner is left
 * untouched (callers pass it zero-initialised).
 *
 * Preconditions (checked by the caller): n_head, n_head_kv, hd_q, hd_kv > 0.
 */
static void hexstate_fl_linear_attention(
    const float *q, int64_t q_stride,
    const float *k, int64_t k_stride,
    const float *v, int64_t v_stride,
    int64_t seq_len, int64_t n_head, int64_t n_head_kv,
    int64_t hd_q, int64_t hd_kv, int spread, float *attn_inner)
{
    const int64_t feat = hd_q < hd_kv ? hd_q : hd_kv;
    const int64_t inner_dim = n_head * hd_kv;

    HPCGraph *graph = hpc_create(n_head);
    float *S = (float *)calloc(n_head * hd_kv * hd_kv, sizeof(float));
    float *z_acc = (float *)calloc(n_head * hd_kv, sizeof(float));
    float *feat_buf = (float *)malloc(2 * n_head * feat * sizeof(float));

    if (graph && S && z_acc && feat_buf) {
        for (int64_t t = 0; t < seq_len; t++) {
            const float *q_row = q + t * q_stride;
            const float *k_row = k + t * k_stride;
            const float *v_row = v + t * v_stride;

            /* 1. Encode K·V energy of every head into its graph site */
            for (int64_t h = 0; h < n_head; h++) {
                const int64_t kv_h = h % n_head_kv;
                const float *kh = k_row + kv_h * hd_kv;
                const float *vh = v_row + kv_h * hd_kv;
                float energy = 0.0f;
                for (int64_t d = 0; d < hd_kv; d++)
                    energy += kh[d] * vh[d];

                double re[D] = {0}, im[D] = {0};
                const float ae = fabsf(energy) + 1e-6f;
                const int ph = hexstate_fl_phase(ae * 100.0f);
                const double amp = sqrt(ae);
                re[ph] = amp;
                im[ph] = (energy < 0) ? -amp * 0.5 : amp * 0.5;
                if (spread) {
                    re[(ph + 1) % D] = amp * 0.2;
                    re[(ph + 5) % D] = amp * 0.2;
                }
                hpc_set_local(graph, h, re, im);
            }

            /* 2. Couple neighbouring heads */
            for (int64_t h = 0; h < n_head - 1; h++)
                hpc_cz(graph, h, h + 1);

            /* 3. Per-head state update and readout */
            #pragma omp parallel for schedule(static)
            for (int64_t h = 0; h < n_head; h++) {
                const int64_t kv_h = h % n_head_kv;
                const float *qh = q_row + h * hd_q;
                const float *kh = k_row + kv_h * hd_kv;
                const float *vh = v_row + kv_h * hd_kv;
                float *Sh = S + h * hd_kv * hd_kv;
                float *zh = z_acc + h * hd_kv;
                float *qf = feat_buf + h * 2 * feat;   /* private slice */
                float *kf = qf + feat;

                float ae = 0.0f;
                for (int64_t d = 0; d < hd_kv; d++)
                    ae += fabsf(kh[d] * vh[d]);
                ae += 1e-6f;
                const int ph = hexstate_fl_phase(ae * 100.0f);
                float coherence = (float)(hpc_marginal(graph, h, ph) * D);
                if (coherence < 0.1f) coherence = 0.1f;
                if (coherence > 3.0f) coherence = 3.0f;

                /* Positive feature maps keep the normaliser non-negative */
                for (int64_t d = 0; d < feat; d++) {
                    qf[d] = (qh[d] > 0 ? qh[d] : 0) + 1e-6f;
                    kf[d] = (kh[d] > 0 ? kh[d] : 0) + 1e-6f;
                }

                for (int64_t d1 = 0; d1 < feat; d1++) {
                    const float ks = kf[d1] * coherence;
                    for (int64_t d2 = 0; d2 < hd_kv; d2++)
                        Sh[d1 * hd_kv + d2] += ks * vh[d2];
                    zh[d1] += kf[d1] * coherence;
                }

                float den = 1e-8f;
                for (int64_t d = 0; d < feat; d++)
                    den += qf[d] * zh[d];
                const float inv_den = 1.0f / den;

                float *ao = attn_inner + t * inner_dim + h * hd_kv;
                for (int64_t d2 = 0; d2 < hd_kv; d2++) {
                    float num = 0.0f;
                    for (int64_t d1 = 0; d1 < feat; d1++)
                        num += qf[d1] * Sh[d1 * hd_kv + d2];
                    ao[d2] = num * inv_den;
                }
            }

            /* 4. Periodic edge compaction */
            if (t > 0 && t % 64 == 0)
                hpc_compact_edges(graph);
        }
    }

    free(feat_buf);
    free(z_acc);
    free(S);
    if (graph) hpc_destroy(graph);
}

/* ═══════════════════════════════════════════════════════════════════════════
 * hexstate_forward_layer — one transformer layer, in place on `hidden`
 *
 * Pipeline:  RMS norm → Q/K/V projection → linear attention (per-head gain
 * read from an HPCGraph) → gate/output projection → residual add →
 * [FFN: RMS norm → SiLU(gate)·up → down → residual add].
 *
 * Every projection goes through hpc_matmul_graph(), which also updates the
 * matching importance accumulator and sample counter.
 *
 * Attention path selection:
 *   - fused:    qkv_w != NULL && qkv_dim > 0. Each projected row is laid out
 *               as [Q: n_head*head_dim | K: n_head_kv*head_dim | V: same].
 *   - separate: q_w, k_w, v_w and o_w all non-NULL. Head widths are
 *               q_dim / n_head and k_dim / n_head_kv.
 *   - neither:  attention contributes nothing.
 *
 * Parameters:
 *   hidden:       [seq_len × n_embd], updated in place
 *   norm_w:       [n_embd] attention-norm gains
 *   qkv_w:        [qkv_dim × n_embd] fused QKV weights (NULL if separate)
 *   q_w/k_w/v_w:  [q_dim|k_dim|v_dim × n_embd] separate weights
 *   gate_w:       fused-path output projection; [gate_rows × n_head*head_dim]
 *                 when gate_trans == 0, [n_head*head_dim × n_embd] otherwise.
 *                 If its output width differs from n_embd only the common
 *                 leading columns are copied.
 *   o_w:          separate-path output projection, read as
 *                 [n_embd × n_head*(k_dim/n_head_kv)]; o_cols is only tested
 *                 for > 0
 *   ffn_norm_w:   [n_embd]; ffn_gate_w/ffn_up_w: [ffn_dim × n_embd];
 *                 ffn_down_w: [n_embd × ffn_dim]. FFN is skipped unless all
 *                 four are non-NULL and ffn_dim > 0.
 *   imp_* / cnt_*: importance accumulators and counters (NULL to skip)
 *   eps:          RMS-norm epsilon
 *
 * Failure handling: the function returns void. If a top-level buffer cannot
 * be allocated it returns early leaving `hidden` as it was at that point.
 * If the head configuration is unusable (zero heads / head width, fused
 * layout wider than qkv_dim, v_dim too small) or the attention buffer cannot
 * be allocated, the attention contribution is zero and the layer continues.
 * ═══════════════════════════════════════════════════════════════════════════ */
void hexstate_forward_layer(
    float *hidden,
    /* Attention weights */
    const float *norm_w,
    const float *qkv_w, int64_t qkv_dim,
    const float *q_w, int64_t q_dim,
    const float *k_w, int64_t k_dim,
    const float *v_w, int64_t v_dim,
    const float *gate_w, int64_t gate_rows,
    const float *o_w, int64_t o_cols,
    int gate_trans, /* explicit transpose flag for gate_w */
    /* FFN weights */
    const float *ffn_norm_w,
    const float *ffn_gate_w, const float *ffn_up_w, const float *ffn_down_w,
    int64_t ffn_dim,
    /* Importance accumulators (NULL to skip) */
    float *imp_qkv, int64_t *cnt_qkv,
    float *imp_q, int64_t *cnt_q,
    float *imp_k, int64_t *cnt_k,
    float *imp_v, int64_t *cnt_v,
    float *imp_gate, int64_t *cnt_gate,
    float *imp_o, int64_t *cnt_o,
    float *imp_ffn_gate, int64_t *cnt_ffn_gate,
    float *imp_ffn_up, int64_t *cnt_ffn_up,
    float *imp_ffn_down, int64_t *cnt_ffn_down,
    /* Architecture */
    int64_t seq_len, int64_t n_embd, int64_t n_head, int64_t n_head_kv,
    int64_t head_dim, float eps)
{
    if (seq_len <= 0 || n_embd <= 0)
        return;

    const int64_t total = seq_len * n_embd;
    float *normed = (float *)malloc(total * sizeof(float));
    float *attn_out = (float *)calloc(total, sizeof(float));
    if (!normed || !attn_out) {
        free(normed);
        free(attn_out);
        return;
    }

    /* ══════════════ Phase 1: Attention Norm ══════════════ */
    hpc_rms_norm(hidden, norm_w, normed, seq_len, n_embd, eps);

    /* ══════════════ Phase 2: QKV projection + attention ══════════════ */
    if (qkv_w && qkv_dim > 0) {
        /* ── Fused QKV path ── */
        float *qkv = (float *)malloc(seq_len * qkv_dim * sizeof(float));
        if (!qkv) {
            free(normed);
            free(attn_out);
            return;
        }
        hpc_matmul_graph(normed, qkv_w, qkv, imp_qkv, cnt_qkv,
                         seq_len, n_embd, qkv_dim, 0);

        const int64_t q_total = n_head * head_dim;
        const int64_t kv_total = n_head_kv * head_dim;
        const int64_t inner_dim = q_total;

        /* The Q|K|V slices must fit inside one projected row. */
        const int layout_ok = n_head > 0 && n_head_kv > 0 && head_dim > 0 &&
                              q_total + 2 * kv_total <= qkv_dim;
        float *attn_inner = layout_ok
            ? (float *)calloc(seq_len * inner_dim, sizeof(float))
            : NULL;

        if (attn_inner) {
            hexstate_fl_linear_attention(qkv, qkv_dim,
                                         qkv + q_total, qkv_dim,
                                         qkv + q_total + kv_total, qkv_dim,
                                         seq_len, n_head, n_head_kv,
                                         head_dim, head_dim, 1, attn_inner);

            if (gate_w && gate_rows > 0) {
                const int64_t N_out = gate_trans ? n_embd : gate_rows;
                float *gated = (float *)malloc(seq_len * N_out * sizeof(float));
                if (gated) {
                    hpc_matmul_graph(attn_inner, gate_w, gated,
                                     imp_gate, cnt_gate,
                                     seq_len, inner_dim, N_out, gate_trans);
                    hexstate_fl_copy_rows(attn_out, n_embd, gated, N_out,
                                          seq_len);
                    free(gated);
                }
            } else {
                hexstate_fl_copy_rows(attn_out, n_embd, attn_inner, inner_dim,
                                      seq_len);
            }
        }

        free(attn_inner);
        free(qkv);
    } else if (q_w && k_w && v_w && o_w) {
        /* ── Separate QKV path (standard transformer) ── */
        float *Q = (float *)malloc(seq_len * q_dim * sizeof(float));
        float *K_buf = (float *)malloc(seq_len * k_dim * sizeof(float));
        float *V_buf = (float *)malloc(seq_len * v_dim * sizeof(float));
        if (!Q || !K_buf || !V_buf) {
            free(Q);
            free(K_buf);
            free(V_buf);
            free(normed);
            free(attn_out);
            return;
        }
        hpc_matmul_graph(normed, q_w, Q, imp_q, cnt_q, seq_len, n_embd, q_dim, 0);
        hpc_matmul_graph(normed, k_w, K_buf, imp_k, cnt_k, seq_len, n_embd, k_dim, 0);
        hpc_matmul_graph(normed, v_w, V_buf, imp_v, cnt_v, seq_len, n_embd, v_dim, 0);

        /* Guard the divisions below and the per-head V slices. */
        const int heads_ok = n_head > 0 && n_head_kv > 0;
        const int64_t hd_q = heads_ok ? q_dim / n_head : 0;
        const int64_t hd_kv = heads_ok ? k_dim / n_head_kv : 0;
        const int64_t inner_dim = n_head * hd_kv;
        const int layout_ok = hd_q > 0 && hd_kv > 0 &&
                              n_head_kv * hd_kv <= v_dim;
        float *attn_inner = layout_ok
            ? (float *)calloc(seq_len * inner_dim, sizeof(float))
            : NULL;

        if (attn_inner) {
            hexstate_fl_linear_attention(Q, q_dim, K_buf, k_dim, V_buf, v_dim,
                                         seq_len, n_head, n_head_kv,
                                         hd_q, hd_kv, 0, attn_inner);

            if (o_cols > 0) {
                /* The matmul overwrites every element of attn_out. */
                hpc_matmul_graph(attn_inner, o_w, attn_out, imp_o, cnt_o,
                                 seq_len, inner_dim, n_embd, 0);
            } else {
                hexstate_fl_copy_rows(attn_out, n_embd, attn_inner, inner_dim,
                                      seq_len);
            }
        }

        free(attn_inner);
        free(Q);
        free(K_buf);
        free(V_buf);
    }

    /* ══════════════ Phase 3: attention residual ══════════════ */
    #pragma omp parallel for schedule(static)
    for (int64_t i = 0; i < total; i++)
        hidden[i] += attn_out[i];

    /* ══════════════ Phase 4: gated FFN + residual ══════════════ */
    if (ffn_norm_w && ffn_gate_w && ffn_up_w && ffn_down_w && ffn_dim > 0) {
        const int64_t ffn_total = seq_len * ffn_dim;
        float *normed_ff = (float *)malloc(total * sizeof(float));
        float *gate_out = (float *)malloc(ffn_total * sizeof(float));
        float *up_out = (float *)malloc(ffn_total * sizeof(float));

        if (normed_ff && gate_out && up_out) {
            hpc_rms_norm(hidden, ffn_norm_w, normed_ff, seq_len, n_embd, eps);
            hpc_matmul_graph(normed_ff, ffn_gate_w, gate_out,
                             imp_ffn_gate, cnt_ffn_gate,
                             seq_len, n_embd, ffn_dim, 0);
            hpc_matmul_graph(normed_ff, ffn_up_w, up_out,
                             imp_ffn_up, cnt_ffn_up,
                             seq_len, n_embd, ffn_dim, 0);

            /* gate_out ← SiLU(gate_out) ⊙ up_out */
            hpc_silu(gate_out, ffn_total);
            #pragma omp parallel for schedule(static)
            for (int64_t i = 0; i < ffn_total; i++)
                gate_out[i] *= up_out[i];

            float *ff_out_buf = (float *)malloc(total * sizeof(float));
            if (ff_out_buf) {
                hpc_matmul_graph(gate_out, ffn_down_w, ff_out_buf,
                                 imp_ffn_down, cnt_ffn_down,
                                 seq_len, ffn_dim, n_embd, 0);
                #pragma omp parallel for schedule(static)
                for (int64_t i = 0; i < total; i++)
                    hidden[i] += ff_out_buf[i];
                free(ff_out_buf);
            }
        }

        free(normed_ff);
        free(gate_out);
        free(up_out);
    }

    free(normed);
    free(attn_out);
}