| """Framework-agnostic trit GEMV library. |
| |
| Loads the pre-compiled libtrit_gemv.so via ctypes. |
| Works with PyTorch, JAX, CuPy, or raw CUDA pointers. |
| |
| Compile the library once: |
| cd kernel/ |
| ./build.sh |
| |
| Then use from any framework: |
| from trit_gemv_lib import TritGEMV |
| lib = TritGEMV() |
| |
| # PyTorch |
| lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng) |
| |
| # Raw pointers (CuPy, JAX, etc.) |
| lib.gemv_d2_ptr(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng) |
| """ |
| import ctypes |
| import os |
| import subprocess |
| import sys |
|
|
| |
# Candidate library filenames, probed in order: Linux .so first, Windows .dll,
# then an unprefixed .so produced by some build setups.
_LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so']
# Directory containing this file; the compiled library is expected alongside it.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
def _find_lib():
    """Return the path of the first compiled library found next to this
    file, or None if none of the candidate names exist."""
    candidates = (os.path.join(_SCRIPT_DIR, fname) for fname in _LIB_NAMES)
    return next((p for p in candidates if os.path.exists(p)), None)
|
|
|
|
def _build_lib():
    """Auto-compile the shared library next to this file.

    Prefers ./build.sh when present; otherwise invokes nvcc directly on
    trit_gemv_standalone.cu, targeting the local GPU's compute capability
    when PyTorch can report it, or a fat binary of common archs otherwise.

    Returns:
        Path to the built library (via _find_lib), or None if the build
        produced nothing.

    Raises:
        FileNotFoundError: no build script and no .cu source to compile.
        subprocess.CalledProcessError: the build command failed (check=True).
    """
    build_script = os.path.join(_SCRIPT_DIR, 'build.sh')
    if os.path.exists(build_script):
        print("Building libtrit_gemv.so...", flush=True)
        subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True)
    else:
        cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu')
        out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so')
        if not os.path.exists(cu_file):
            raise FileNotFoundError(f"Cannot find {cu_file}")

        try:
            # Target exactly the local GPU when torch can tell us its arch.
            import torch
            major, minor = torch.cuda.get_device_capability(0)
            gencode = [f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"]
        except Exception:
            # torch missing or no CUDA device: build a fat binary covering
            # Volta through Hopper. (Was a bare `except:`, which also
            # swallowed KeyboardInterrupt/SystemExit.)
            gencode = [
                f"-gencode=arch=compute_{arch},code=sm_{arch}"
                for arch in ("70", "75", "80", "86", "89", "90")
            ]

        # List form with shell=False: paths containing spaces or shell
        # metacharacters are passed through safely (the old f-string +
        # shell=True command broke on such paths).
        cmd = ['nvcc', '-O3', '--use_fast_math', '-shared',
               '-Xcompiler', '-fPIC', *gencode, '-o', out_file, cu_file]
        print(f"Compiling: {' '.join(cmd)}", flush=True)
        subprocess.run(cmd, check=True)

    return _find_lib()
|
|
|
|
class TritGEMV:
    """Framework-agnostic trit GEMV kernel.

    Thin ctypes wrapper around libtrit_gemv.so. Accepts PyTorch tensors,
    anything exposing ``__cuda_array_interface__`` (CuPy, JAX, Numba), or
    raw integer device pointers -- see ``_get_ptr``.
    """

    def __init__(self, lib_path=None):
        """Load (or locate/build) the shared library and declare C signatures.

        Args:
            lib_path: explicit path to the compiled library. If None, search
                next to this file and auto-build on miss.

        Raises:
            RuntimeError: the library can neither be found nor built.
        """
        if lib_path is None:
            lib_path = _find_lib()
            if lib_path is None:
                lib_path = _build_lib()
            if lib_path is None:
                raise RuntimeError("Cannot find or build libtrit_gemv.so")

        self._lib = ctypes.CDLL(lib_path)

        # Declare argtypes so ctypes marshals pointers/ints correctly on
        # platforms where the defaults would truncate 64-bit pointers.
        PTR, INT = ctypes.c_void_p, ctypes.c_int

        # trit_gemv_d2_dp4a(pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist)
        self._lib.trit_gemv_d2_dp4a.argtypes = [PTR] * 5 + [INT] * 4
        self._lib.trit_gemv_d2_dp4a.restype = None

        # trit_gemv_d3_native(pt, sc, x, y, cols, rows, depth)
        self._lib.trit_gemv_d3_native.argtypes = [PTR] * 4 + [INT] * 3
        self._lib.trit_gemv_d3_native.restype = None

        # trit_gemv_d3_int8_dp4a(wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist)
        self._lib.trit_gemv_d3_int8_dp4a.argtypes = [PTR] * 5 + [INT] * 4
        self._lib.trit_gemv_d3_int8_dp4a.restype = None

        self._lib.get_l2_cache_bytes.restype = ctypes.c_int
        self._lib.cuda_sync.restype = None

        # Query GPU identity and L2 size once; gemv_adaptive uses l2_bytes
        # to pick between the L2-resident and DRAM kernel paths.
        buf = ctypes.create_string_buffer(256)
        self._lib.get_gpu_name(buf, 256)
        self.gpu_name = buf.value.decode()
        self.l2_bytes = self._lib.get_l2_cache_bytes()

    def sync(self):
        """Block until all queued GPU work has completed."""
        self._lib.cuda_sync()

    def _get_ptr(self, tensor):
        """Extract a GPU pointer (as int) from any framework's tensor.

        Supports: objects with ``data_ptr()`` (PyTorch), objects exposing
        ``__cuda_array_interface__`` (CuPy/JAX/Numba), and raw int pointers.

        Raises:
            TypeError: the object exposes none of the supported interfaces.
        """
        if hasattr(tensor, 'data_ptr'):
            return tensor.data_ptr()
        elif hasattr(tensor, '__cuda_array_interface__'):
            return tensor.__cuda_array_interface__['data'][0]
        elif isinstance(tensor, int):
            return tensor
        else:
            raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")

    def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D2 GEMV with int4 packing + dp4a.

        Args:
            pt: int32 tensor [rows * num_groups * 8] -- int4 packed weights
            ws: float32 tensor [rows * num_groups] -- weight scales
            xt: int32 tensor [num_groups * 16] -- int8 packed activations
            xs: float32 tensor [num_groups] -- activation scales
            y: float32 tensor [rows] -- output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: K // 64
            l2_persist: enable L2 cache persistence (default True)
        """
        self._lib.trit_gemv_d2_dp4a(
            self._get_ptr(pt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                      pt_int8=None):
        """Hardware-aware GEMV: auto-selects best kernel based on L2 cache.

        If the int4 weight data fits in L2 -> d2 int4 + dp4a path.
        If not -> pre-expanded int8 + dp4a path (no decode overhead).

        Args:
            pt_int4: int32 tensor -- int4 packed weights (always stored, compact)
            ws: weight scales
            xt, xs: quantized activations
            y: output
            pt_int8: optional pre-expanded int8 weights for the DRAM path.

        Raises:
            ValueError: weights exceed L2 and pt_int8 was not supplied.
        """
        # int4 layout: 8 int32 words per (row, group), 4 bytes per word.
        weight_bytes = rows * num_groups * 8 * 4
        l2_margin = self.l2_bytes * 0.75  # leave headroom for activations/output

        if weight_bytes < l2_margin:
            # Compact int4 path: weights can stay resident in L2, so enable
            # persistence. Delegates to gemv_d2 instead of duplicating the
            # raw kernel call.
            self.gemv_d2(pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                         l2_persist=True)
        else:
            # DRAM-bound: decode overhead would dominate, so require the
            # pre-expanded int8 weights and skip L2 persistence.
            if pt_int8 is None:
                raise ValueError(
                    f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 ({self.l2_bytes/1e6:.0f} MB). "
                    f"Provide pre-expanded pt_int8 for DRAM path. "
                    f"Use TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
                )
            self.gemv_d3_int8(pt_int8, ws, xt, xs, y, cols, rows, num_groups,
                              l2_persist=False)

    @staticmethod
    def expand_int4_to_int8(pt_int4, device='cuda'):
        """Pre-expand int4 packed weights to int8 for DRAM-bound layers.

        Called once at model load. Uses 2x more VRAM but eliminates decode
        overhead. int4: 8 words per group -> int8: 16 words per group.

        Args:
            pt_int4: int32 tensor [n_groups * 8] -- int4 packed
            device: device for the returned tensor (default 'cuda')
        Returns:
            int32 tensor [n_groups * 16] -- int8 packed (dp4a compatible)
        """
        import torch
        n_words = pt_int4.shape[0]
        n_groups = n_words // 8

        # Vectorized unpack. The previous per-nibble Python loop made O(n)
        # .item() calls and could overflow on store: `|= (val << 24)` with
        # val=0xFF exceeds INT32_MAX, which torch rejects for int32 tensors.
        words = pt_int4.reshape(-1)[:n_groups * 8].to(torch.int32).unsqueeze(1)
        nib_shifts = 4 * torch.arange(8, device=words.device, dtype=torch.int32)
        nib = (words >> nib_shifts) & 0xF            # [n_words, 8], 0..15
        nib = torch.where(nib >= 8, nib - 16, nib)   # sign-extend int4 -> int8

        # Re-pack 4 bytes per int32 word, byte k at bits [8k, 8k+8) -- the
        # same layout the old `val << (out_byte * 8)` produced. Work in
        # int64 so 0xFF << 24 cannot overflow, then wrap to int32 range.
        lanes = (nib.to(torch.int64) & 0xFF).reshape(n_groups, 16, 4)
        byte_shifts = 8 * torch.arange(4, device=words.device, dtype=torch.int64)
        packed = (lanes << byte_shifts).sum(dim=2)
        packed = torch.where(packed >= 2**31, packed - 2**32, packed)
        return packed.reshape(n_groups * 16).to(dtype=torch.int32, device=device)

    def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
        """D3 GEMV with native trit packing.

        Args:
            pt: int32 tensor [rows * ng * 13] -- trit packed weights
            sc: float32 tensor [rows * ng] -- scales
            x: float32 tensor [cols] -- activations
            y: float32 tensor [rows] -- output (written in-place)
            depth: trit depth passed through to the kernel (default 3)
        """
        self._lib.trit_gemv_d3_native(
            self._get_ptr(pt), self._get_ptr(sc),
            self._get_ptr(x), self._get_ptr(y),
            cols, rows, depth,
        )

    def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D3 GEMV with int8 level packing + dp4a (same quality as d3, dp4a speed).

        Args:
            wt: int32 tensor [rows * num_groups * 16] -- int8 packed levels
            ws: float32 tensor [rows * num_groups] -- weight scales
            xt: int32 tensor [num_groups * 16] -- int8 packed activations
            xs: float32 tensor [num_groups * 16] -- per-word x scales
            y: float32 tensor [rows] -- output (written in-place)
            l2_persist: enable L2 cache persistence (default True)

        Raises:
            RuntimeError: the loaded library predates this kernel.
        """
        if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            # Message mojibake repaired (was a garbled em-dash).
            raise RuntimeError("d3 int8 not in this build - rebuild libtrit_gemv.so")
        self._lib.trit_gemv_d3_int8_dp4a(
            self._get_ptr(wt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def __repr__(self):
        return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"
|
|