---
license: apache-2.0
tags:
- cuda
- quantization
- ternary
- llm-inference
- kernel
---
| |
| # tritllm-kernel |
|
|
| Multiply-free ternary GEMV CUDA kernel for the codec from |
| **"Balanced Ternary Post-Training Quantization for Large Language Models"** (Stentzel, 2026). |
|
|
The headline number from the paper: **7.8× speedup** over cuBLAS FP16 GEMV on an RTX 4090 in the memory-bound regime, with full-model token-generation throughput projected from per-layer benchmarks.
|
|
| > **These are kernel-only projections, not end-to-end serving throughput.** They exclude attention, KV cache, sampling, and tokenizer overhead. See Section 7 of the paper for methodology. |
|
|
| ## What it is |
|
|
A standalone CUDA shared library (`libtrit_gemv.so` / `.dll`) callable via `ctypes` from any language, with no PyTorch dependency. The same algorithm is also exposed to PyTorch through a pybind11 extension in `trit_gemv_wrapper.cu` for benchmarking.
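
For instance, a host program can load and query the library with plain `dlopen` (a minimal sketch; the two symbols are from the C ABI listed under "API surface" below):

```c
#include <dlfcn.h>
#include <stdio.h>

/* Sketch: load the standalone library with no framework at all.
 * Symbol names are from the C ABI in trit_gemv_standalone.cu. */
int main(void) {
    void *lib = dlopen("./libtrit_gemv.so", RTLD_NOW);
    if (!lib) { fprintf(stderr, "%s\n", dlerror()); return 1; }

    int  (*l2_bytes)(void)        = (int (*)(void))dlsym(lib, "get_l2_cache_bytes");
    void (*gpu_name)(char *, int) = (void (*)(char *, int))dlsym(lib, "get_gpu_name");

    char name[128];
    gpu_name(name, sizeof name);
    printf("%s: %d bytes of L2\n", name, l2_bytes());

    dlclose(lib);
    return 0;
}
```

On Linux this compiles with, e.g., `cc demo.c -ldl`.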
|
|
The core trick: each ternary weight (-1, 0, +1) reduces a multiply-accumulate to a conditional add/subtract/skip. The kernel applies this through the `dp4a` intrinsic (available on every targeted architecture, Volta through Blackwell) to int4-packed weights and pre-interleaved int8 activations, evaluating a four-element ternary×int8 dot product per instruction.
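
As a scalar sketch of that reduction (the real kernels vectorize it, four lanes at a time, through `dp4a`):

```c
#include <stdint.h>

/* Sketch: a ternary dot product with no multiplies. Each trit picks
 * add, subtract, or skip; dp4a evaluates four such int8 lanes per
 * instruction, and per-group scales are applied afterwards. */
float ternary_dot(const int8_t *trits, const int8_t *x, int n, float scale) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i) {
        if (trits[i] > 0)      acc += x[i];  /* +1: add      */
        else if (trits[i] < 0) acc -= x[i];  /* -1: subtract */
        /* 0: skip */
    }
    return scale * (float)acc;
}
```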
|
|
| ## Build |
|
|
| ```bash |
| cd kernel |
| ./build.sh |
| ``` |
|
|
| The build script targets SM 70/75/80/86/89/90/100/120 in one fat binary so the `.so` runs on V100, T4, A100, RTX 30/40/50, H100, and B100/B200 without recompilation. |
|
|
Required: `nvcc` and a host C++ compiler. CUDA 11.8 or newer covers the SM 70–90 targets; building the SM 100/120 (Blackwell) targets additionally requires CUDA 12.8 or newer.
|
|
| ## Performance (Qwen2.5-7B, d=2, 3.47 bpw, 3.3 GB model) |
|
|
| | GPU | L2 cache | Tokens/sec | Speedup vs FP16 cuBLAS | Effective BW | |
| |---|---|---|---|---| |
| | RTX 4090 | 72 MB | 588 | 7.8× | 1940 GB/s | |
| | RTX 3090 | 6 MB | 192 | 3.4× | 633 GB/s | |
| | RTX 4080 Laptop | 64 MB | 133 | 5.8× | 439 GB/s | |
| | A100 80GB | 40 MB | 201 | 4.2× | 663 GB/s | |
|
|
These are per-layer GEMV benchmarks projected to full-model token-generation throughput. L2 cache size correlates strongly with speedup because each `d=2` layer fits entirely in L2 on the RTX 4090, giving an effective bandwidth of roughly 2× the card's DRAM bandwidth.
|
|
| See `kernel/bench_*.py` for the benchmark drivers. |
|
|
| ## Launch contract |
|
|
| The kernels in `trit_gemv.cu` and `trit_gemv_standalone.cu` assume: |
|
|
| | Constraint | Why | What happens if violated | |
| |---|---|---| |
| | `blockDim.x == 32` (one warp per block) | Kernels use `__shfl_down_sync(0xFFFFFFFF, ...)` and lane-0 reduction | OOB index reads + race on `y[row]` | |
| | `in_features % 64 == 0` | Group size is fixed at 64 weights | Trailing partial group is silently dropped — incorrect output for that row | |
| | Weight, scale, and activation buffers are device-resident and properly aligned | Kernel uses `__ldg` for cached loads | UB / device fault | |
|
|
If your model's `in_features` is not divisible by 64, pad the weight matrix along the input dimension to the next multiple of 64 with zeros before quantizing; the padded weights quantize to trit 0 and contribute nothing.
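
A minimal sketch of the rounding (the dimensions in the comment are illustrative):

```c
/* Sketch: round in_features up to the fixed group size of 64.
 * e.g. 11000 -> 11008; an already-aligned 11008 is unchanged. */
static int pad_to_group(int in_features) {
    return (in_features + 63) / 64 * 64;
}
```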
|
|
| ## API surface |
|
|
| C ABI in `trit_gemv_standalone.cu`: |
|
|
| ```c |
| // Best-tested d=2 path (champion for 4090) |
| void trit_gemv_d2_fast( |
| const int32_t* pt, // [rows * num_groups * 8] int4-packed weights |
| const float* ws, // [rows * num_groups] scales |
| const int32_t* xt_e, // [num_groups * 8] even nibble activations |
| const int32_t* xt_o, // [num_groups * 8] odd nibble activations |
| const float* xs, // [num_groups] activation scales |
| float* y, // [rows] output |
| int cols, int rows, int num_groups, |
| int use_l2_persist // 0 = off, 1 = enable L2 persistence |
| ); |
| |
| // Native-trit packed d=3 (no int4 intermediate) |
| void trit_gemv_d3_native( |
| const int32_t* pt, // [rows * num_groups * 13] trit-packed |
| const float* sc, |
| const float* x, |
| float* y, |
| int cols, int rows, int depth |
| ); |
| |
| // L2 cache size query (for deciding whether to enable persist) |
| int get_l2_cache_bytes(); |
| void get_gpu_name(char* buf, int buflen); |
| void cuda_sync(); |
| ``` |
|
|
| ## Error reporting |
|
|
| All `extern "C"` entry points return `void`, so per-call status is delivered through a separate channel: |
|
|
| ```c |
| int trit_gemv_get_last_error(); |
| ``` |
|
|
| Returns `0` on success. Negative values are host-side argument-validation failures (`TRIT_ERR_NULL_PTR`, `TRIT_ERR_BAD_DIM`, `TRIT_ERR_BAD_GROUP`, `TRIT_ERR_BAD_BUFFER`). Positive values are `cudaError_t` codes captured from the most recent kernel launch. |
|
|
| The host-side validator in each entry point checks pointer non-null, positive dimensions, `cols % 64 == 0`, and `cols / 64 == num_groups`. If validation fails, no kernel is launched, the error is recorded, and the call returns silently. |
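
The resulting call pattern looks like this (a sketch; the declaration is repeated for a standalone caller):

```c
#include <stdio.h>

int trit_gemv_get_last_error(void);  /* from libtrit_gemv */

/* Sketch: branch on the sign convention described above. */
static int check_last_call(const char *what) {
    int err = trit_gemv_get_last_error();
    if (err == 0)
        return 0;                                    /* success */
    if (err < 0)
        fprintf(stderr, "%s: rejected by host-side validation (%d)\n", what, err);
    else
        fprintf(stderr, "%s: CUDA error %d from the launch\n", what, err);
    return err;
}
```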
|
|
| ## Known limitations |
|
|
| The educational kernels in `trit_gemv.cu` use a lane-0 scale-and-add reduction that idles 31 lanes per group. This is a deliberate readability tradeoff — the headline 7.8× number is from the deferred-reduction `k_d3_hardened` kernel in `trit_gemv_standalone.cu`. See [KNOWN_ISSUES.md](KNOWN_ISSUES.md) for details and a planned API-cleanup item. |
|
|
| ## Citation |
|
|
| ``` |
| @article{stentzel2026ternaryptq, |
| title = {Balanced Ternary Post-Training Quantization for Large Language Models}, |
| author = {Stentzel, Eric}, |
| year = 2026, |
| note = {Entrit Systems} |
| } |
| ``` |
|
|
| ## License |
|
|
| Apache-2.0. |
|
|