# tritllm-kernel / trit_gemv_lib.py
# Initial public release: code, README, KNOWN_ISSUES (commit 51e3123, verified)
"""Framework-agnostic trit GEMV library.
Loads the pre-compiled libtrit_gemv.so via ctypes.
Works with PyTorch, JAX, CuPy, or raw CUDA pointers.
Compile the library once:
cd kernel/
./build.sh
Then use from any framework:
from trit_gemv_lib import TritGEMV
lib = TritGEMV()
# PyTorch
lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng)
# Raw pointers (CuPy, JAX, etc.)
lib.gemv_d2_ptr(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng)
"""
import ctypes
import os
import subprocess
import sys
# Find the library
# Candidate shared-library filenames, probed in order in this module's directory.
_LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so']
# Absolute directory of this file; the compiled library is expected alongside it.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
def _find_lib():
    """Return the path of the first existing library candidate, or None.

    Probes each name in _LIB_NAMES inside _SCRIPT_DIR, in order.
    """
    candidates = (os.path.join(_SCRIPT_DIR, name) for name in _LIB_NAMES)
    return next((path for path in candidates if os.path.exists(path)), None)
def _build_lib():
    """Auto-compile the shared library if it was not found.

    Prefers the repo's build.sh; otherwise invokes nvcc directly on the
    standalone .cu file.  Targets the local GPU's compute capability when
    torch + CUDA are available, else builds a fat binary for common archs.

    Returns:
        Path to the freshly built library, or None if the build produced
        no recognizable artifact.

    Raises:
        FileNotFoundError: if neither build.sh nor the .cu source exists.
        subprocess.CalledProcessError: if the build command fails.
    """
    build_script = os.path.join(_SCRIPT_DIR, 'build.sh')
    if os.path.exists(build_script):
        print("Building libtrit_gemv.so...", flush=True)
        subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True)
    else:
        # Inline build
        cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu')
        out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so')
        if not os.path.exists(cu_file):
            raise FileNotFoundError(f"Cannot find {cu_file}")
        # Detect GPU architecture; fall back to a multi-arch fat binary when
        # torch is missing or no CUDA device is visible.  Narrow except so
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            import torch
            cc = torch.cuda.get_device_capability(0)
            gencode = [f"-gencode=arch=compute_{cc[0]}{cc[1]},code=sm_{cc[0]}{cc[1]}"]
        except Exception:
            # Default to common architectures
            gencode = [
                f"-gencode=arch=compute_{a},code=sm_{a}"
                for a in ("70", "75", "80", "86", "89", "90")
            ]
        # argv list + shell=False: robust to spaces in paths, no shell parsing.
        cmd = [
            'nvcc', '-O3', '--use_fast_math', '-shared',
            '-Xcompiler', '-fPIC', *gencode, '-o', out_file, cu_file,
        ]
        print(f"Compiling: {' '.join(cmd)}", flush=True)
        subprocess.run(cmd, check=True)
    return _find_lib()
class TritGEMV:
    """Framework-agnostic trit GEMV kernel.

    Thin ctypes wrapper around libtrit_gemv.so.  Tensor arguments may be
    PyTorch tensors (``data_ptr``), any object exposing
    ``__cuda_array_interface__`` (CuPy, JAX, Numba), or raw integer device
    pointers.
    """

    def __init__(self, lib_path=None):
        """Load (or find/build) the shared library and declare C signatures.

        Args:
            lib_path: explicit path to the library.  When None, the library
                is searched next to this module and auto-built if missing.

        Raises:
            RuntimeError: if the library can neither be found nor built.
        """
        if lib_path is None:
            lib_path = _find_lib()
        if lib_path is None:
            lib_path = _build_lib()
        if lib_path is None:
            raise RuntimeError("Cannot find or build libtrit_gemv.so")
        self._lib = ctypes.CDLL(lib_path)

        # Set up function signatures so ctypes marshals arguments correctly.
        # d2 dp4a (champion)
        self._lib.trit_gemv_d2_dp4a.argtypes = [
            ctypes.c_void_p,  # pt (int32*) int4-packed weights
            ctypes.c_void_p,  # ws (float*) weight scales
            ctypes.c_void_p,  # xt (int32*) int8-packed activations
            ctypes.c_void_p,  # xs (float*) activation scales
            ctypes.c_void_p,  # y (float*) output
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d2_dp4a.restype = None

        # d3 native trit
        self._lib.trit_gemv_d3_native.argtypes = [
            ctypes.c_void_p,  # pt
            ctypes.c_void_p,  # sc
            ctypes.c_void_p,  # x
            ctypes.c_void_p,  # y
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # depth
        ]
        self._lib.trit_gemv_d3_native.restype = None

        # d3 int8 dp4a (no decode, DRAM-bound path)
        self._lib.trit_gemv_d3_int8_dp4a.argtypes = [
            ctypes.c_void_p,  # wt (int32*)
            ctypes.c_void_p,  # ws (float*)
            ctypes.c_void_p,  # xt (int32*)
            ctypes.c_void_p,  # xs (float*)
            ctypes.c_void_p,  # y (float*)
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d3_int8_dp4a.restype = None

        # Utility entry points.
        self._lib.get_l2_cache_bytes.restype = ctypes.c_int
        self._lib.cuda_sync.restype = None
        buf = ctypes.create_string_buffer(256)
        self._lib.get_gpu_name(buf, 256)
        # Human-readable GPU name reported by the library (device 0).
        self.gpu_name = buf.value.decode()
        # L2 cache size in bytes; drives kernel selection in gemv_adaptive.
        self.l2_bytes = self._lib.get_l2_cache_bytes()

    def sync(self):
        """Block the host until the library's pending CUDA work completes."""
        self._lib.cuda_sync()

    def _get_ptr(self, tensor):
        """Extract a raw GPU pointer (int) from any framework's tensor.

        Raises:
            TypeError: if no pointer can be extracted from *tensor*.
        """
        if hasattr(tensor, 'data_ptr'):
            # PyTorch
            return tensor.data_ptr()
        elif hasattr(tensor, '__cuda_array_interface__'):
            # CuPy, JAX, Numba
            return tensor.__cuda_array_interface__['data'][0]
        elif isinstance(tensor, int):
            # Raw pointer
            return tensor
        else:
            raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")

    def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D2 GEMV with int4 packing + dp4a.

        Args:
            pt: int32 tensor [rows * num_groups * 8] — int4 packed weights
            ws: float32 tensor [rows * num_groups] — weight scales
            xt: int32 tensor [num_groups * 16] — int8 packed activations
            xs: float32 tensor [num_groups] — activation scales
            y: float32 tensor [rows] — output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: K // 64
            l2_persist: enable L2 cache persistence (default True)
        """
        self._lib.trit_gemv_d2_dp4a(
            self._get_ptr(pt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                      pt_int8=None):
        """Hardware-aware GEMV: auto-selects best kernel based on L2 cache.

        If the int4 weight data fits in L2 → uses d2 int4 + dp4a (5x FP16).
        If not → uses pre-expanded int8 + dp4a (2x FP16, no decode overhead).

        Args:
            pt_int4: int32 tensor — int4 packed weights (always stored, compact)
            ws: weight scales
            xt, xs: quantized activations
            y: output
            pt_int8: optional pre-expanded int8 weights for the DRAM path.

        Raises:
            ValueError: if the layer exceeds the L2 budget and *pt_int8* was
                not supplied.
        """
        weight_bytes = rows * num_groups * 8 * 4  # int4: 8 words per group
        l2_margin = self.l2_bytes * 0.75  # leave 25% for x, scales, other data
        if weight_bytes < l2_margin:
            # Fits in L2 → use compact int4, decode inline at L2 speed
            self._lib.trit_gemv_d2_dp4a(
                self._get_ptr(pt_int4), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 1)
        else:
            # Doesn't fit L2 → use int8 for zero-decode DRAM speed
            if pt_int8 is None:
                raise ValueError(
                    f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 ({self.l2_bytes/1e6:.0f} MB). "
                    f"Provide pre-expanded pt_int8 for DRAM path. "
                    f"Use TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
                )
            self._lib.trit_gemv_d3_int8_dp4a(
                self._get_ptr(pt_int8), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 0)

    @staticmethod
    def expand_int4_to_int8(pt_int4, device='cuda'):
        """Pre-expand int4 packed weights to int8 for DRAM-bound layers.

        Called once at model load.  Uses 2x more VRAM but eliminates the
        in-kernel decode.  int4: 8 words per group → int8: 16 words per group.

        Fully vectorized on the target device.  This also fixes an int32
        overflow in the previous scalar version: a negative int4 nibble in
        byte lane 3 produced `val << 24 >= 2**31`, which cannot be stored
        into an int32 tensor element; here packed words are folded back into
        the signed int32 range before the final cast.

        Args:
            pt_int4: int32 tensor [n_groups * 8] — int4 packed
            device: device to compute on / place the result (default 'cuda')

        Returns:
            int32 tensor [n_groups * 16] — int8 packed (dp4a compatible)
        """
        import torch
        n_groups = pt_int4.shape[0] // 8
        # Reinterpret each int32 word as its unsigned 32-bit value in int64
        # so all intermediate shifts/sums are overflow-free.
        u = pt_int4.reshape(-1)[: n_groups * 8].to(device=device, dtype=torch.int64)
        u = u & 0xFFFFFFFF
        # [n_words, 8]: nibble k of each word, little-endian nibble order.
        nib_shift = torch.arange(8, device=u.device, dtype=torch.int64) * 4
        nibbles = (u.unsqueeze(1) >> nib_shift) & 0xF
        # Sign-extend int4 (-8..7), then keep the two's-complement low byte.
        lanes = (torch.where(nibbles >= 8, nibbles - 16, nibbles) & 0xFF)
        # Group layout: out_col = word * 8 + nibble → 64 int8 lanes per group.
        lanes = lanes.reshape(n_groups, 64)
        # Pack 4 bytes per int32 word, little-endian byte order.
        byte_shift = torch.arange(4, device=u.device, dtype=torch.int64) * 8
        packed = (lanes.reshape(n_groups, 16, 4) << byte_shift).sum(dim=-1)
        # Fold back into signed int32 range before the dtype cast.
        packed = torch.where(packed >= 2 ** 31, packed - 2 ** 32, packed)
        return packed.to(torch.int32).reshape(-1)

    def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
        """D3 GEMV with native trit packing.

        Args:
            pt: int32 tensor [rows * ng * 13] — trit packed weights
            sc: float32 tensor [rows * ng] — scales
            x: float32 tensor [cols] — activations
            y: float32 tensor [rows] — output (written in-place)
            depth: trit depth (default 3)
        """
        self._lib.trit_gemv_d3_native(
            self._get_ptr(pt), self._get_ptr(sc),
            self._get_ptr(x), self._get_ptr(y),
            cols, rows, depth,
        )

    def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D3 GEMV with int8 level packing + dp4a (same quality as d3, dp4a speed).

        Args:
            wt: int32 tensor [rows * num_groups * 16] — int8 packed levels
            ws: float32 tensor [rows * num_groups] — weight scales
            xt: int32 tensor [num_groups * 16] — int8 packed activations
            xs: float32 tensor [num_groups * 16] — per-word x scales
            y: float32 tensor [rows] — output (written in-place)

        Raises:
            RuntimeError: if the loaded library lacks the d3 int8 entry point.
        """
        if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            raise RuntimeError("d3 int8 not in this build — rebuild libtrit_gemv.so")
        self._lib.trit_gemv_d3_int8_dp4a(
            self._get_ptr(wt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def __repr__(self):
        return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"