"""Framework-agnostic trit GEMV library. Loads the pre-compiled libtrit_gemv.so via ctypes. Works with PyTorch, JAX, CuPy, or raw CUDA pointers. Compile the library once: cd kernel/ ./build.sh Then use from any framework: from trit_gemv_lib import TritGEMV lib = TritGEMV() # PyTorch lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng) # Raw pointers (CuPy, JAX, etc.) lib.gemv_d2_ptr(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng) """ import ctypes import os import subprocess import sys # Find the library _LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so'] _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) def _find_lib(): for name in _LIB_NAMES: path = os.path.join(_SCRIPT_DIR, name) if os.path.exists(path): return path return None def _build_lib(): """Auto-compile if not found.""" build_script = os.path.join(_SCRIPT_DIR, 'build.sh') if os.path.exists(build_script): print("Building libtrit_gemv.so...", flush=True) subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True) else: # Inline build cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu') out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so') if not os.path.exists(cu_file): raise FileNotFoundError(f"Cannot find {cu_file}") # Detect GPU architecture try: import torch cc = torch.cuda.get_device_capability(0) arch = f"compute_{cc[0]}{cc[1]}" sm = f"sm_{cc[0]}{cc[1]}" gencode = f"-gencode=arch={arch},code={sm}" except: # Default to common architectures gencode = " ".join([ f"-gencode=arch=compute_{a},code=sm_{a}" for a in ["70", "75", "80", "86", "89", "90"] ]) cmd = f"nvcc -O3 --use_fast_math -shared -Xcompiler -fPIC {gencode} -o {out_file} {cu_file}" print(f"Compiling: {cmd}", flush=True) subprocess.run(cmd, shell=True, check=True) return _find_lib() class TritGEMV: """Framework-agnostic trit GEMV kernel.""" def __init__(self, lib_path=None): if lib_path is None: lib_path = _find_lib() if lib_path is None: lib_path = _build_lib() if lib_path is None: raise RuntimeError("Cannot find or build libtrit_gemv.so") self._lib = ctypes.CDLL(lib_path) # Set up function signatures # d2 dp4a (champion) self._lib.trit_gemv_d2_dp4a.argtypes = [ ctypes.c_void_p, # pt (int32*) ctypes.c_void_p, # ws (float*) ctypes.c_void_p, # xt (int32*) ctypes.c_void_p, # xs (float*) ctypes.c_void_p, # y (float*) ctypes.c_int, # cols ctypes.c_int, # rows ctypes.c_int, # num_groups ctypes.c_int, # use_l2_persist ] self._lib.trit_gemv_d2_dp4a.restype = None # d3 native trit self._lib.trit_gemv_d3_native.argtypes = [ ctypes.c_void_p, # pt ctypes.c_void_p, # sc ctypes.c_void_p, # x ctypes.c_void_p, # y ctypes.c_int, # cols ctypes.c_int, # rows ctypes.c_int, # depth ] self._lib.trit_gemv_d3_native.restype = None # d3 int8 dp4a (no decode, DRAM-bound path) self._lib.trit_gemv_d3_int8_dp4a.argtypes = [ ctypes.c_void_p, # wt (int32*) ctypes.c_void_p, # ws (float*) ctypes.c_void_p, # xt (int32*) ctypes.c_void_p, # xs (float*) ctypes.c_void_p, # y (float*) ctypes.c_int, # cols ctypes.c_int, # rows ctypes.c_int, # num_groups ctypes.c_int, # use_l2_persist ] self._lib.trit_gemv_d3_int8_dp4a.restype = None # Utility self._lib.get_l2_cache_bytes.restype = ctypes.c_int self._lib.cuda_sync.restype = None buf = ctypes.create_string_buffer(256) self._lib.get_gpu_name(buf, 256) self.gpu_name = buf.value.decode() self.l2_bytes = self._lib.get_l2_cache_bytes() def sync(self): self._lib.cuda_sync() def _get_ptr(self, tensor): """Extract GPU pointer from any framework's tensor.""" if hasattr(tensor, 'data_ptr'): 

class TritGEMV:
    """Framework-agnostic trit GEMV kernel."""

    def __init__(self, lib_path=None):
        if lib_path is None:
            lib_path = _find_lib()
        if lib_path is None:
            lib_path = _build_lib()
        if lib_path is None:
            raise RuntimeError("Cannot find or build libtrit_gemv.so")
        self._lib = ctypes.CDLL(lib_path)

        # Set up function signatures.
        # d2 dp4a (champion)
        self._lib.trit_gemv_d2_dp4a.argtypes = [
            ctypes.c_void_p,  # pt (int32*)
            ctypes.c_void_p,  # ws (float*)
            ctypes.c_void_p,  # xt (int32*)
            ctypes.c_void_p,  # xs (float*)
            ctypes.c_void_p,  # y (float*)
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d2_dp4a.restype = None

        # d3 native trit
        self._lib.trit_gemv_d3_native.argtypes = [
            ctypes.c_void_p,  # pt
            ctypes.c_void_p,  # sc
            ctypes.c_void_p,  # x
            ctypes.c_void_p,  # y
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # depth
        ]
        self._lib.trit_gemv_d3_native.restype = None

        # d3 int8 dp4a (no decode, DRAM-bound path); may be absent in
        # older builds, so guard the signature setup
        if hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            self._lib.trit_gemv_d3_int8_dp4a.argtypes = [
                ctypes.c_void_p,  # wt (int32*)
                ctypes.c_void_p,  # ws (float*)
                ctypes.c_void_p,  # xt (int32*)
                ctypes.c_void_p,  # xs (float*)
                ctypes.c_void_p,  # y (float*)
                ctypes.c_int,     # cols
                ctypes.c_int,     # rows
                ctypes.c_int,     # num_groups
                ctypes.c_int,     # use_l2_persist
            ]
            self._lib.trit_gemv_d3_int8_dp4a.restype = None

        # Utility
        self._lib.get_l2_cache_bytes.restype = ctypes.c_int
        self._lib.cuda_sync.restype = None

        buf = ctypes.create_string_buffer(256)
        self._lib.get_gpu_name(buf, 256)
        self.gpu_name = buf.value.decode()
        self.l2_bytes = self._lib.get_l2_cache_bytes()

    def sync(self):
        self._lib.cuda_sync()

    def _get_ptr(self, tensor):
        """Extract a GPU pointer from any framework's tensor."""
        if hasattr(tensor, 'data_ptr'):
            # PyTorch
            return tensor.data_ptr()
        elif hasattr(tensor, '__cuda_array_interface__'):
            # CuPy, JAX, Numba
            return tensor.__cuda_array_interface__['data'][0]
        elif isinstance(tensor, int):
            # Raw device pointer
            return tensor
        else:
            raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")

    def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups,
                l2_persist=True):
        """D2 GEMV with int4 packing + dp4a.

        Args:
            pt: int32 tensor [rows * num_groups * 8] — int4 packed weights
            ws: float32 tensor [rows * num_groups] — weight scales
            xt: int32 tensor [num_groups * 16] — int8 packed activations
            xs: float32 tensor [num_groups] — activation scales
            y: float32 tensor [rows] — output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: K // 64
            l2_persist: enable L2 cache persistence (default True)
        """
        self._lib.trit_gemv_d2_dp4a(
            self._get_ptr(pt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y),
            cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                      pt_int8=None):
        """Hardware-aware GEMV: auto-selects the best kernel based on L2 cache.

        If the int4 weight data fits in L2 → uses d2 int4 + dp4a (5x FP16).
        If not → uses pre-expanded int8 + dp4a (2x FP16, no decode overhead).

        Args:
            pt_int4: int32 tensor — int4 packed weights (always stored, compact)
            ws: weight scales
            xt, xs: quantized activations
            y: output
            pt_int8: optional pre-expanded int8 weights for the DRAM path.
                Required when the layer does not fit in L2; build it once at
                model load with expand_int4_to_int8, otherwise a ValueError
                is raised.
        """
        weight_bytes = rows * num_groups * 8 * 4  # int4: 8 words per group
        l2_margin = self.l2_bytes * 0.75  # leave 25% for x, scales, other data
        if weight_bytes < l2_margin:
            # Fits in L2 → use compact int4, decode inline at L2 speed
            self._lib.trit_gemv_d2_dp4a(
                self._get_ptr(pt_int4), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y),
                cols, rows, num_groups, 1)
        else:
            # Doesn't fit in L2 → use int8 for zero-decode DRAM speed
            if pt_int8 is None:
                raise ValueError(
                    f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 "
                    f"({self.l2_bytes/1e6:.0f} MB). Provide pre-expanded "
                    f"pt_int8 for the DRAM path. Use "
                    f"TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
                )
            self._lib.trit_gemv_d3_int8_dp4a(
                self._get_ptr(pt_int8), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y),
                cols, rows, num_groups, 0)

    @staticmethod
    def expand_int4_to_int8(pt_int4, device='cuda'):
        """Pre-expand int4 packed weights to int8 for DRAM-bound layers.

        Called once at model load. Uses 2x more VRAM but eliminates decode
        overhead. int4: 8 words per group → int8: 16 words per group.

        Args:
            pt_int4: int32 tensor [n_groups * 8] — int4 packed

        Returns:
            int32 tensor [n_groups * 16] — int8 packed (dp4a compatible)
        """
        import torch
        n_words = pt_int4.shape[0]
        n_groups = n_words // 8
        # Each int4 word holds 8 nibbles → 8 int8 values → 2 int8x4 words
        pt_int8 = torch.zeros(n_groups * 16, dtype=torch.int32, device=device)
        # Scalar reference expansion (one-time cost at model load). Note:
        # .item() pulls each word to the host, so this loop is not
        # GPU-vectorized even though the output lives on `device`.
        for g in range(n_groups):
            for w in range(8):
                word = int(pt_int4[g * 8 + w].item())
                for nib in range(8):
                    val = (word >> (nib * 4)) & 0xF
                    if val & 0x8:
                        val |= 0xFFFFFFF0  # sign-extend the 4-bit value
                    val &= 0xFF            # keep only the low byte
                    out_col = w * 8 + nib
                    out_word = out_col // 4
                    out_byte = out_col % 4
                    packed = val << (out_byte * 8)
                    if packed > 0x7FFFFFFF:
                        packed -= 0x100000000  # wrap into signed int32 range
                    pt_int8[g * 16 + out_word] |= packed
        return pt_int8
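
    # Worked example of the mapping above (hand-picked illustrative values,
    # not taken from a real model): nibble `nib` of int4 word `w` within a
    # group lands at output column c = w * 8 + nib, i.e. byte c % 4 of int8
    # word c // 4. For w = 1, nib = 5: c = 13, so the value goes into byte 1
    # of output word 3. A nibble of 0xE sign-extends to -2 and is stored as
    # the byte 0xFE, which dp4a later reads back as a signed int8.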

    def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
        """D3 GEMV with native trit packing.

        Args:
            pt: int32 tensor [rows * ng * 13] — trit packed weights
            sc: float32 tensor [rows * ng] — scales
            x: float32 tensor [cols] — activations
            y: float32 tensor [rows] — output
        """
        self._lib.trit_gemv_d3_native(
            self._get_ptr(pt), self._get_ptr(sc),
            self._get_ptr(x), self._get_ptr(y),
            cols, rows, depth,
        )

    def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups,
                     l2_persist=True):
        """D3 GEMV with int8 level packing + dp4a (same quality as d3, at dp4a speed).

        Args:
            wt: int32 tensor [rows * num_groups * 16] — int8 packed levels
            ws: float32 tensor [rows * num_groups] — weight scales
            xt: int32 tensor [num_groups * 16] — int8 packed activations
            xs: float32 tensor [num_groups * 16] — per-word x scales
            y: float32 tensor [rows] — output
        """
        if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            raise RuntimeError(
                "d3 int8 not in this build — rebuild libtrit_gemv.so")
        self._lib.trit_gemv_d3_int8_dp4a(
            self._get_ptr(wt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y),
            cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def __repr__(self):
        return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"
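
# --- Hedged smoke test -------------------------------------------------------
# A minimal sketch, assuming a CUDA-capable GPU and that libtrit_gemv.so sits
# next to this file (or can be built via build.sh / the inline nvcc fallback).
# Running `python trit_gemv_lib.py` loads the library, reports the detected
# GPU, and exercises the d2 path on the random placeholder data above.
if __name__ == "__main__":
    lib = TritGEMV()
    print(lib)  # e.g. TritGEMV(gpu='NVIDIA ...', l2=..MB)
    y = _example_d2_pytorch(lib)
    print("d2 sketch ran; first outputs:", y[:4].tolist())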