---
license: apache-2.0
tags:
- cuda
- quantization
- ternary
- llm-inference
- kernel
---

# tritllm-kernel

Multiply-free ternary GEMV CUDA kernel for the codec from **"Balanced Ternary Post-Training Quantization for Large Language Models"** (Stentzel, 2026).

The headline number from the paper: **7.8× speedup** over cuBLAS FP16 GEMV on RTX 4090 in the memory-bound regime, projected to full-model token generation throughput from per-layer benchmarks.

> **These are kernel-only projections, not end-to-end serving throughput.** They exclude attention, KV cache, sampling, and tokenizer overhead. See Section 7 of the paper for methodology.

## What it is

A standalone CUDA shared library (`libtrit_gemv.so` / `.dll`) callable via `ctypes` from any language, with no PyTorch dependency. The same algorithm is also wrapped via PyTorch's pybind11 in `trit_gemv_wrapper.cu` for benchmarking.

The core trick: each ternary weight (-1, 0, +1) reduces a multiply-accumulate to a conditional add/subtract/skip. The kernel uses Ada/Hopper/Blackwell `dp4a` intrinsics on int4-packed weights and pre-interleaved int8 activations to do four ternary-times-int8 dot products per instruction.

## Build

```bash
cd kernel
./build.sh
```

The build script targets SM 70/75/80/86/89/90/100/120 in one fat binary, so the `.so` runs on V100, T4, A100, RTX 30/40/50, H100, and B100/B200 without recompilation. Required: `nvcc` (CUDA 11.8 or newer) and a C++ compiler.

## Performance (Qwen2.5-7B, d=2, 3.47 bpw, 3.3 GB model)

| GPU | L2 cache | Tokens/sec | Speedup vs FP16 cuBLAS | Effective BW |
|---|---|---|---|---|
| RTX 4090 | 72 MB | 588 | 7.8× | 1940 GB/s |
| RTX 3090 | 6 MB | 192 | 3.4× | 633 GB/s |
| RTX 4080 Laptop | 64 MB | 133 | 5.8× | 439 GB/s |
| A100 80GB | 40 MB | 201 | 4.2× | 663 GB/s |

These are per-layer GEMV benchmarks projected to full-model token-generation throughput.
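The Effective BW column follows directly from the other numbers: at batch size 1, generating a token streams the full weight footprint once, so effective bandwidth ≈ tokens/sec × model bytes. A quick sanity check in Python (the `effective_bw_gbs` helper is illustrative, not part of the repo, and decimal GB for the 3.3 GB footprint is an assumption; the results land within a rounding step of the table):

```python
# Sanity-check the Effective BW column: at batch 1, each generated token
# reads the entire 3.3 GB weight footprint once from memory.
MODEL_GB = 3.3  # Qwen2.5-7B at d=2, 3.47 bpw (decimal GB assumed)

def effective_bw_gbs(tokens_per_sec: float, model_gb: float = MODEL_GB) -> float:
    """Effective memory bandwidth in GB/s implied by a token rate."""
    return tokens_per_sec * model_gb

for gpu, tps in [("RTX 4090", 588), ("RTX 3090", 192),
                 ("RTX 4080 Laptop", 133), ("A100 80GB", 201)]:
    print(f"{gpu}: {effective_bw_gbs(tps):.0f} GB/s")
```

The RTX 4090 figure (≈1940 GB/s) exceeds its ~1 TB/s GDDR6X bandwidth, which is only possible because the `d=2` layers are served from L2 rather than DRAM.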
The L2-cache size correlates strongly with speedup because each `d=2` layer fits in L2 on the RTX 4090, giving an effective bandwidth roughly 2× HBM bandwidth. See `kernel/bench_*.py` for the benchmark drivers.

## Launch contract

The kernels in `trit_gemv.cu` and `trit_gemv_standalone.cu` assume:

| Constraint | Why | What happens if violated |
|---|---|---|
| `blockDim.x == 32` (one warp per block) | Kernels use `__shfl_down_sync(0xFFFFFFFF, ...)` and lane-0 reduction | OOB index reads + race on `y[row]` |
| `in_features % 64 == 0` | Group size is fixed at 64 weights | Trailing partial group is silently dropped; incorrect output for that row |
| Weight, scale, and activation buffers are device-resident and properly aligned | Kernel uses `__ldg` for cached loads | UB / device fault |

If your model's `in_features` is not divisible by 64, pad the weight matrix along the `in_features` dimension with zeros to the next multiple of 64 before quantizing.

## API surface

C ABI in `trit_gemv_standalone.cu`:

```c
// Best-tested d=2 path (champion for 4090)
void trit_gemv_d2_fast(
    const int32_t* pt,    // [rows * num_groups * 8] int4-packed weights
    const float*   ws,    // [rows * num_groups] scales
    const int32_t* xt_e,  // [num_groups * 8] even nibble activations
    const int32_t* xt_o,  // [num_groups * 8] odd nibble activations
    const float*   xs,    // [num_groups] activation scales
    float*         y,     // [rows] output
    int cols, int rows, int num_groups,
    int use_l2_persist    // 0 = off, 1 = enable L2 persistence
);

// Native-trit packed d=3 (no int4 intermediate)
void trit_gemv_d3_native(
    const int32_t* pt,    // [rows * num_groups * 13] trit-packed
    const float*   sc,
    const float*   x,
    float*         y,
    int cols, int rows, int depth
);

// L2 cache size query (for deciding whether to enable persist)
int get_l2_cache_bytes();

void get_gpu_name(char* buf, int buflen);
void cuda_sync();
```

## Error reporting

All `extern "C"` entry points return `void`, so per-call status is delivered through a separate channel:

```c
int trit_gemv_get_last_error();
```
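Because the entry points return `void`, a caller should poll the error channel after every call. A minimal ctypes sketch of that check-after-call pattern (the `classify_trit_error` and `checked_call` helpers are illustrative wrappers, not part of the C ABI; only the zero/negative/positive convention comes from the library):

```python
import ctypes

def classify_trit_error(code: int) -> str:
    """Coarse category for a trit_gemv_get_last_error() code:
    0 = success, negative = host-side validation failure,
    positive = cudaError_t from the most recent kernel launch."""
    if code == 0:
        return "ok"
    return "validation" if code < 0 else "cuda"

def checked_call(lib, fn_name: str, *args) -> None:
    """Call an extern "C" entry point, then poll the error channel.
    Raises instead of failing silently on bad arguments."""
    getattr(lib, fn_name)(*args)
    code = lib.trit_gemv_get_last_error()
    if code != 0:
        raise RuntimeError(
            f"{fn_name} failed: {classify_trit_error(code)} (code={code})")

# Usage (requires the built library and a CUDA device):
#   lib = ctypes.CDLL("./libtrit_gemv.so")
#   checked_call(lib, "trit_gemv_d2_fast", pt, ws, xt_e, xt_o, xs, y,
#                cols, rows, num_groups, 0)
```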
Returns `0` on success. Negative values are host-side argument-validation failures (`TRIT_ERR_NULL_PTR`, `TRIT_ERR_BAD_DIM`, `TRIT_ERR_BAD_GROUP`, `TRIT_ERR_BAD_BUFFER`). Positive values are `cudaError_t` codes captured from the most recent kernel launch.

The host-side validator in each entry point checks that pointers are non-null, dimensions are positive, `cols % 64 == 0`, and `cols / 64 == num_groups`. If validation fails, no kernel is launched, the error is recorded, and the call returns silently.

## Known limitations

The educational kernels in `trit_gemv.cu` use a lane-0 scale-and-add reduction that idles 31 of the 32 lanes per group. This is a deliberate readability tradeoff; the headline 7.8× number is from the deferred-reduction `k_d3_hardened` kernel in `trit_gemv_standalone.cu`. See [KNOWN_ISSUES.md](KNOWN_ISSUES.md) for details and a planned API-cleanup item.

## Citation

```
@article{stentzel2026ternaryptq,
  title  = {Balanced Ternary Post-Training Quantization for Large Language Models},
  author = {Stentzel, Eric},
  year   = 2026,
  note   = {Entrit Systems}
}
```

## License

Apache-2.0.