CompressedGemma commited on
Commit
63a70a0
Β·
verified Β·
1 Parent(s): 57f4b1d

Q8_0 tied embeddings

Browse files
Files changed (1) hide show
  1. hexstate_quantize.c +352 -0
hexstate_quantize.c CHANGED
@@ -633,6 +633,7 @@ static int is_attention_tensor(const char *gguf_name)
633
  * conservative too" β€” creating coherent precision allocation.
634
  * ═══════════════════════════════════════════════════════════════════════════ */
635
 
 
636
  /* ── Multi-quhit expanded scale table ──
637
  * Search grid: 24Γ—24 = 576 (d, dmin) candidates
638
  * Quhit encoding: bin 24 β†’ 6 for D=6 quhits (BP operates on 6-state marginals)
@@ -2246,6 +2247,345 @@ static void quantize_tensor_q4_0_hpc(const float *weights, int64_t n_elements,
2246
  free(best_candidate);
2247
  }
2248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2249
  /* Re-derive the 4-bit sub-scale codes (Ls, Lm) for a candidate (d, dmin)
2250
  * pair from the Phase-1 float scales/mins. Bit-identical to the Phase-2b
2251
  * candidate generation, so stored codes are unnecessary. */
@@ -4619,6 +4959,18 @@ void hexstate_quantize_tensor_q4_0_hpc(const float *weights, int64_t n_elements,
4619
  if (out_error) *out_error = err;
4620
  }
4621
 
 
 
 
 
 
 
 
 
 
 
 
 
4622
  #ifndef HEXSTATE_LIBRARY
4623
  /* ═══════════════════════════════════════════════════════════════════════════
4624
  * MAIN
 
633
  * conservative too" β€” creating coherent precision allocation.
634
  * ═══════════════════════════════════════════════════════════════════════════ */
635
 
636
+
637
  /* ── Multi-quhit expanded scale table ──
638
  * Search grid: 24Γ—24 = 576 (d, dmin) candidates
639
  * Quhit encoding: bin 24 β†’ 6 for D=6 quhits (BP operates on 6-state marginals)
 
2247
  free(best_candidate);
2248
  }
2249
 
2250
+ /* ════════════════════════════════════════════════════════════════════════
2251
+ * Q8_0 HPC QUANTIZER β€” Shor pipeline at 8 bits
2252
+ *
2253
+ * Same pipeline as Q4_0: WLS scale + tight candidate grid scored on the
2254
+ * extended objective (weighted SSE + DC + vesica/wave), triality-quhit
2255
+ * graph with Boltzmann-encoded candidate errors, CZ chain entanglement,
2256
+ * Shor Griffiths-Niu sequential measurement for bin consensus, greedy
2257
+ * override (HEX_GREEDY_OVERRIDE_RATIO), then per-block ULP polish, the
2258
+ * vesica/DC error-shaping descent with an extended-objective guard, and
2259
+ * the candidate floor. Intended for embedding / LM-head tensors (tied
2260
+ * embeddings especially), where 2-4 bit codes destroy logit precision.
2261
+ * At 8 bits the candidate grid is tight (Β±1.5%) β€” the win over naive
2262
+ * amax/127 rounding comes from WLS + ULP + spectral selection, not from
2263
+ * coarse scale exploration.
2264
+ * ════════════════════════════════════════════════════════════════════════ */
2265
+
2266
+ #ifndef QK8_0
2267
+ #define QK8_0 32
2268
+ #endif
2269
+ typedef struct { uint16_t d; int8_t qs[QK8_0]; } hex_block_q8_0;
2270
+
2271
+ #define Q8_N_CAND 24
2272
+ static const float Q8_NEIGHBOR_MULTS[Q8_N_CAND] = {
2273
+ 0.9850f, 0.9865f, 0.9880f, 0.9895f, 0.9910f, 0.9925f,
2274
+ 0.9940f, 0.9952f, 0.9964f, 0.9976f, 0.9988f, 1.0000f,
2275
+ 1.0010f, 1.0020f, 1.0030f, 1.0040f, 1.0052f, 1.0064f,
2276
+ 1.0076f, 1.0088f, 1.0100f, 1.0115f, 1.0130f, 1.0150f,
2277
+ };
2278
+ /* 24 candidates β†’ 6 quhit states (4 per bin), same folding as Q4_0 */
2279
+ static const int Q8_CAND_TO_QUHIT[Q8_N_CAND] = {
2280
+ 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3, 4,4,4,4, 5,5,5,5
2281
+ };
2282
+
2283
+ static inline float q8_block_ext_err(const float *bw, const float *iw,
2284
+ float d, int8_t *qs_out)
2285
+ {
2286
+ float e_arr[QK8_0];
2287
+ float id = (fabsf(d) > 1e-20f) ? 1.0f / d : 0.0f;
2288
+ float err = 0.0f;
2289
+ for (int j = 0; j < QK8_0; j++) {
2290
+ int q = gguf_nearest_int(bw[j] * id);
2291
+ if (q < -127) q = -127; if (q > 127) q = 127;
2292
+ if (qs_out) qs_out[j] = (int8_t)q;
2293
+ float e = bw[j] - (float)q * d;
2294
+ e_arr[j] = e;
2295
+ float w = iw ? iw[j] : 1.0f;
2296
+ err += e * e * w;
2297
+ }
2298
+ return err + hex_spectral_penalty(e_arr, QK8_0);
2299
+ }
2300
+
2301
+ static void quantize_tensor_q8_0_hpc(const float *weights, int64_t n_elements,
2302
+ hex_block_q8_0 *output,
2303
+ float *out_total_error,
2304
+ const float *imat_importance, int verbose)
2305
+ {
2306
+ int64_t n_blocks = n_elements / QK8_0;
2307
+ float total_err = 0.0f;
2308
+ (void)verbose;
2309
+
2310
+ float (*cand_errors)[Q8_N_CAND] = (float (*)[Q8_N_CAND])
2311
+ calloc(n_blocks, sizeof(float[Q8_N_CAND]));
2312
+ uint16_t (*cand_d16)[Q8_N_CAND] = (uint16_t (*)[Q8_N_CAND])
2313
+ calloc(n_blocks, sizeof(uint16_t[Q8_N_CAND]));
2314
+ int *best_candidate = (int *)malloc(n_blocks * sizeof(int));
2315
+ if (!cand_errors || !cand_d16 || !best_candidate) {
2316
+ free(cand_errors); free(cand_d16); free(best_candidate);
2317
+ if (out_total_error) *out_total_error = -1.0f;
2318
+ return;
2319
+ }
2320
+
2321
+ /* ── Phase 1+2: WLS-refined scale + tight candidate grid ── */
2322
+ #pragma omp parallel for schedule(dynamic, 256)
2323
+ for (int64_t blk = 0; blk < n_blocks; blk++) {
2324
+ const float *bw = weights + blk * QK8_0;
2325
+ const float *iw = imat_importance ? imat_importance + blk * QK8_0 : NULL;
2326
+
2327
+ float amax = 0.0f;
2328
+ for (int j = 0; j < QK8_0; j++) {
2329
+ float av = fabsf(bw[j]);
2330
+ if (av > amax) amax = av;
2331
+ }
2332
+ float wls_d = amax / 127.0f;
2333
+
2334
+ /* ggml-style fixed-point WLS with DC rank-1 augmentation */
2335
+ for (int it = 0; it < 3 && wls_d > 1e-20f; it++) {
2336
+ float inv_d = 1.0f / wls_d;
2337
+ float num = 0.0f, den = 0.0f, dcS = 0.0f, dcQ = 0.0f;
2338
+ for (int j = 0; j < QK8_0; j++) {
2339
+ int q = gguf_nearest_int(bw[j] * inv_d);
2340
+ if (q < -127) q = -127; if (q > 127) q = 127;
2341
+ float qf = (float)q;
2342
+ float w = iw ? iw[j] : 1.0f;
2343
+ num += w * bw[j] * qf;
2344
+ den += w * qf * qf;
2345
+ dcS += bw[j];
2346
+ dcQ += qf;
2347
+ }
2348
+ num += (HEX_DC_LAMBDA / (float)QK8_0) * dcS * dcQ;
2349
+ den += (HEX_DC_LAMBDA / (float)QK8_0) * dcQ * dcQ;
2350
+ if (den > 1e-15f) {
2351
+ float d_new = num / den;
2352
+ if (d_new > 1e-20f) wls_d = d_new;
2353
+ }
2354
+ }
2355
+
2356
+ for (int ci = 0; ci < Q8_N_CAND; ci++) {
2357
+ float trial_d = wls_d * Q8_NEIGHBOR_MULTS[ci];
2358
+ uint16_t d16 = gguf_fp32_to_fp16(trial_d);
2359
+ float actual_d = gguf_fp16_to_fp32(d16);
2360
+ cand_d16 [blk][ci] = d16;
2361
+ cand_errors[blk][ci] = q8_block_ext_err(bw, iw, actual_d, NULL);
2362
+ }
2363
+ best_candidate[blk] = 11; /* Γ—1.0000 neutral seed */
2364
+ }
2365
+
2366
+ /* ── Phase 3: Shor graph β€” triality quhits, CZ chain, GN measurement ── */
2367
+ int shor_ran = 0;
2368
+ if (n_blocks >= 2) {
2369
+ int64_t graph_blocks = (n_blocks > 200) ? 200 : n_blocks;
2370
+ int64_t stride = n_blocks / graph_blocks;
2371
+
2372
+ HPCGraph *graph = hpc_create(graph_blocks);
2373
+ if (graph) {
2374
+ shor_ran = 1;
2375
+
2376
+ /* Adaptive temperature from the candidate-error landscape */
2377
+ float temperature = 1e-10f;
2378
+ {
2379
+ double err_accum = 0.0;
2380
+ int err_count = 0;
2381
+ for (int64_t gi = 0; gi < graph_blocks && gi < 100; gi++) {
2382
+ int64_t blk = gi * stride;
2383
+ float max_e = 0.0f;
2384
+ for (int c = 0; c < Q8_N_CAND; c++)
2385
+ if (cand_errors[blk][c] > max_e)
2386
+ max_e = cand_errors[blk][c];
2387
+ err_accum += (double)max_e;
2388
+ err_count++;
2389
+ }
2390
+ if (err_count > 0) {
2391
+ temperature = (float)(err_accum / err_count) * 0.1f;
2392
+ if (temperature < 1e-10f) temperature = 1e-10f;
2393
+ }
2394
+ }
2395
+
2396
+ /* Boltzmann-encode stride-aggregated candidate errors as
2397
+ * quhit amplitudes (24 candidates folded into 6 states) */
2398
+ for (int64_t i = 0; i < graph_blocks; i++) {
2399
+ float agg_errors[Q8_N_CAND];
2400
+ for (int c = 0; c < Q8_N_CAND; c++) agg_errors[c] = 0.0f;
2401
+ int64_t blk_start = i * stride;
2402
+ int64_t blk_end = blk_start + stride;
2403
+ if (blk_end > n_blocks) blk_end = n_blocks;
2404
+ for (int64_t b = blk_start; b < blk_end; b++)
2405
+ for (int c = 0; c < Q8_N_CAND; c++)
2406
+ agg_errors[c] += cand_errors[b][c];
2407
+ float min_err = 1e30f;
2408
+ for (int c = 0; c < Q8_N_CAND; c++)
2409
+ if (agg_errors[c] < min_err) min_err = agg_errors[c];
2410
+
2411
+ double amp_re[6] = {0,0,0,0,0,0};
2412
+ double amp_norm = 0.0;
2413
+ for (int ci = 0; ci < Q8_N_CAND; ci++)
2414
+ amp_re[Q8_CAND_TO_QUHIT[ci]] +=
2415
+ exp(-(double)(agg_errors[ci] - min_err) /
2416
+ (2.0 * (double)temperature));
2417
+ for (int v = 0; v < 6; v++) amp_norm += amp_re[v] * amp_re[v];
2418
+ if (amp_norm > 1e-30) {
2419
+ double inv = 1.0 / sqrt(amp_norm);
2420
+ for (int v = 0; v < 6; v++) amp_re[v] *= inv;
2421
+ }
2422
+ for (int v = 0; v < 6; v++) {
2423
+ graph->locals[i].edge_re[v] = amp_re[v];
2424
+ graph->locals[i].edge_im[v] = 0.0;
2425
+ }
2426
+ graph->locals[i].primary = VIEW_EDGE;
2427
+ graph->locals[i].dirty = DIRTY_VERTEX | DIRTY_DIAGONAL | DIRTY_FOLDED;
2428
+ graph->locals[i].delta_valid = 0;
2429
+ triality_update_mask(&graph->locals[i]);
2430
+ }
2431
+
2432
+ for (int64_t i = 0; i < graph_blocks - 1; i++)
2433
+ hpc_cz(graph, i, i + 1);
2434
+
2435
+ double (*marg)[6] = (double (*)[6])calloc(graph_blocks, sizeof(double[6]));
2436
+ int *measured = (int *)calloc(graph_blocks, sizeof(int));
2437
+ if (marg && measured) {
2438
+ shor_measure_graph(graph, graph_blocks, marg, measured, 1);
2439
+
2440
+ /* Per-block selection: best candidate inside the Shor-
2441
+ * measured bin, then greedy override against the global
2442
+ * argmin β€” identical Step-F semantics to Q2_K/Q4_0. */
2443
+ for (int64_t i = 0; i < graph_blocks; i++) {
2444
+ int bin = measured[i];
2445
+ if (bin < 0 || bin > 5) {
2446
+ double bm = -1.0; bin = 0;
2447
+ for (int v = 0; v < 6; v++)
2448
+ if (marg[i][v] > bm) { bm = marg[i][v]; bin = v; }
2449
+ }
2450
+ int64_t blk_start = i * stride;
2451
+ int64_t blk_end = blk_start + stride;
2452
+ if (blk_end > n_blocks) blk_end = n_blocks;
2453
+ for (int64_t b = blk_start; b < blk_end; b++) {
2454
+ float bin_best = 1e30f; int bin_cand = -1;
2455
+ float g_best = 1e30f; int g_cand = 0;
2456
+ for (int c = 0; c < Q8_N_CAND; c++) {
2457
+ float e = cand_errors[b][c];
2458
+ if (e < g_best) { g_best = e; g_cand = c; }
2459
+ if (Q8_CAND_TO_QUHIT[c] == bin && e < bin_best) {
2460
+ bin_best = e; bin_cand = c;
2461
+ }
2462
+ }
2463
+ int sel = (bin_cand >= 0) ? bin_cand : g_cand;
2464
+ if (g_best < cand_errors[b][sel] * HEX_GREEDY_OVERRIDE_RATIO)
2465
+ sel = g_cand;
2466
+ best_candidate[b] = sel;
2467
+ }
2468
+ }
2469
+ }
2470
+ free(marg); free(measured);
2471
+ hpc_destroy(graph);
2472
+ }
2473
+ }
2474
+ if (!shor_ran) {
2475
+ for (int64_t blk = 0; blk < n_blocks; blk++) {
2476
+ float g_best = cand_errors[blk][0]; int g_cand = 0;
2477
+ for (int c = 1; c < Q8_N_CAND; c++)
2478
+ if (cand_errors[blk][c] < g_best) {
2479
+ g_best = cand_errors[blk][c]; g_cand = c;
2480
+ }
2481
+ best_candidate[blk] = g_cand;
2482
+ }
2483
+ }
2484
+
2485
+ /* ── Phase 4: ULP polish + vesica/DC shaping guard + floor ── */
2486
+ #pragma omp parallel for schedule(dynamic, 256) reduction(+:total_err)
2487
+ for (int64_t blk = 0; blk < n_blocks; blk++) {
2488
+ const float *bw = weights + blk * QK8_0;
2489
+ const float *iw = imat_importance ? imat_importance + blk * QK8_0 : NULL;
2490
+ int cidx = best_candidate[blk];
2491
+
2492
+ uint16_t best_d16 = cand_d16[blk][cidx];
2493
+ float best_err = cand_errors[blk][cidx];
2494
+
2495
+ /* Β±8 fp16 ULP joint search on the extended objective */
2496
+ for (int du = -8; du <= 8; du++) {
2497
+ if (du == 0) continue;
2498
+ int c16 = (int)cand_d16[blk][cidx] + du;
2499
+ if (c16 <= 0 || c16 > 0x7BFF) continue;
2500
+ float td = gguf_fp16_to_fp32((uint16_t)c16);
2501
+ float err = q8_block_ext_err(bw, iw, td, NULL);
2502
+ if (err < best_err) { best_err = err; best_d16 = (uint16_t)c16; }
2503
+ }
2504
+
2505
+ /* Candidate floor: final ≀ best raw grid candidate (by construction
2506
+ * the ULP search already starts from it, so this is implicit). */
2507
+ float d = gguf_fp16_to_fp32(best_d16);
2508
+ int8_t qs[QK8_0];
2509
+ (void)q8_block_ext_err(bw, iw, d, qs);
2510
+
2511
+ /* Vesica/DC greedy shaping with extended-objective guard */
2512
+ {
2513
+ int8_t qs_shaped[QK8_0];
2514
+ memcpy(qs_shaped, qs, QK8_0);
2515
+ float e_live[QK8_0], v_live[QK8_0 / 2];
2516
+ float vesica_cur = 0.0f, dc_cur = 0.0f;
2517
+ for (int k = 0; k < QK8_0; k++)
2518
+ e_live[k] = bw[k] - (float)qs_shaped[k] * d;
2519
+ for (int p = 0; p < QK8_0 / 2; p++) {
2520
+ v_live[p] = e_live[p] + e_live[p + QK8_0 / 2];
2521
+ vesica_cur += v_live[p] * v_live[p];
2522
+ dc_cur += v_live[p];
2523
+ }
2524
+ float metric_cur = 4.0f * vesica_cur + dc_cur * dc_cur;
2525
+ for (int pass = 0; pass < QK8_0; pass++) {
2526
+ int best_k = -1, best_q_alt = 0;
2527
+ float best_delta = 0.0f;
2528
+ for (int k = 0; k < QK8_0; k++) {
2529
+ int q_try = (e_live[k] >= 0.0f) ? qs_shaped[k] + 1
2530
+ : qs_shaped[k] - 1;
2531
+ if (q_try < -127 || q_try > 127) continue;
2532
+ float e_new = bw[k] - (float)q_try * d;
2533
+ float de = e_new - e_live[k];
2534
+ int pi = (k < QK8_0 / 2) ? k : k - QK8_0 / 2;
2535
+ float v_new = v_live[pi] + de;
2536
+ float ves_a = vesica_cur - v_live[pi] * v_live[pi]
2537
+ + v_new * v_new;
2538
+ float dc_a = dc_cur + de;
2539
+ float delta = metric_cur - (4.0f * ves_a + dc_a * dc_a);
2540
+ if (delta > best_delta) {
2541
+ best_delta = delta; best_k = k; best_q_alt = q_try;
2542
+ }
2543
+ }
2544
+ if (best_k < 0) break;
2545
+ {
2546
+ float e_new = bw[best_k] - (float)best_q_alt * d;
2547
+ float de = e_new - e_live[best_k];
2548
+ int pi = (best_k < QK8_0 / 2) ? best_k
2549
+ : best_k - QK8_0 / 2;
2550
+ float v_new = v_live[pi] + de;
2551
+ vesica_cur += v_new * v_new - v_live[pi] * v_live[pi];
2552
+ dc_cur += de;
2553
+ metric_cur = 4.0f * vesica_cur + dc_cur * dc_cur;
2554
+ v_live[pi] = v_new;
2555
+ e_live[best_k] = e_new;
2556
+ qs_shaped[best_k] = (int8_t)best_q_alt;
2557
+ }
2558
+ }
2559
+ /* Guard on the extended objective vs originals */
2560
+ float e_b[QK8_0], e_s[QK8_0];
2561
+ float err_b = 0.0f, err_s = 0.0f;
2562
+ for (int k = 0; k < QK8_0; k++) {
2563
+ float w = iw ? iw[k] : 1.0f;
2564
+ e_b[k] = bw[k] - (float)qs[k] * d;
2565
+ e_s[k] = bw[k] - (float)qs_shaped[k] * d;
2566
+ err_b += e_b[k] * e_b[k] * w;
2567
+ err_s += e_s[k] * e_s[k] * w;
2568
+ }
2569
+ err_b += hex_spectral_penalty(e_b, QK8_0);
2570
+ err_s += hex_spectral_penalty(e_s, QK8_0);
2571
+ if (err_s < err_b) memcpy(qs, qs_shaped, QK8_0);
2572
+ }
2573
+
2574
+ output[blk].d = best_d16;
2575
+ for (int k = 0; k < QK8_0; k++) {
2576
+ output[blk].qs[k] = qs[k];
2577
+ float e = bw[k] - (float)qs[k] * d;
2578
+ total_err += e * e; /* pure reconstruction SSE report */
2579
+ }
2580
+ }
2581
+
2582
+ free(cand_errors);
2583
+ free(cand_d16);
2584
+ free(best_candidate);
2585
+ if (out_total_error) *out_total_error = total_err;
2586
+ }
2587
+
2588
+
2589
  /* Re-derive the 4-bit sub-scale codes (Ls, Lm) for a candidate (d, dmin)
2590
  * pair from the Phase-1 float scales/mins. Bit-identical to the Phase-2b
2591
  * candidate generation, so stored codes are unnecessary. */
 
4959
  if (out_error) *out_error = err;
4960
  }
4961
 
4962
+ int hexstate_q8_0_block_bytes(void) { return (int)sizeof(hex_block_q8_0); }
4963
+ int hexstate_q8_0_block_elements(void) { return QK8_0; }
4964
+
4965
+ void hexstate_quantize_tensor_q8_0_hpc(const float *weights, int64_t n_elements,
4966
+ void *output, float *out_error,
4967
+ const float *imat_importance, int verbose)
4968
+ {
4969
+ quantize_tensor_q8_0_hpc(weights, n_elements,
4970
+ (hex_block_q8_0 *)output, out_error,
4971
+ imat_importance, verbose);
4972
+ }
4973
+
4974
  #ifndef HEXSTATE_LIBRARY
4975
  /* ═══════════════════════════════════════════════════════════════════════════
4976
  * MAIN