"""Framework-agnostic trit GEMV library.

Loads the pre-compiled libtrit_gemv.so via ctypes.
Works with PyTorch, JAX, CuPy, or raw CUDA pointers.

Compile the library once:

    cd kernel/
    ./build.sh

Then use from any framework:

    from trit_gemv_lib import TritGEMV
    lib = TritGEMV()

    # PyTorch
    lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng)

    # Raw pointers (CuPy, JAX, etc.)
    lib.gemv_d2_ptr(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng)
"""
import ctypes
import os
import subprocess
import sys
# Candidate shared-library filenames, tried in order by _find_lib().
_LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so']
# Directory containing this file; the library is searched for (and built) here.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
def _find_lib():
    """Return the path of the first library candidate found next to this file, or None."""
    candidates = (os.path.join(_SCRIPT_DIR, fname) for fname in _LIB_NAMES)
    return next((path for path in candidates if os.path.exists(path)), None)
def _build_lib():
    """Auto-compile the shared library and return its path (None if still absent).

    Prefers build.sh when present; otherwise invokes nvcc directly on
    trit_gemv_standalone.cu, targeting the local GPU's compute capability
    when torch can report it, else a set of common architectures.

    Raises:
        FileNotFoundError: if neither build.sh nor the .cu source exists.
        subprocess.CalledProcessError: if the build command fails.
    """
    build_script = os.path.join(_SCRIPT_DIR, 'build.sh')
    if os.path.exists(build_script):
        print("Building libtrit_gemv.so...", flush=True)
        subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True)
    else:
        # Inline build
        cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu')
        out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so')
        if not os.path.exists(cu_file):
            raise FileNotFoundError(f"Cannot find {cu_file}")
        # Detect GPU architecture via torch; fall back to a fat binary.
        try:
            import torch
            cc = torch.cuda.get_device_capability(0)
            gencode = [f"-gencode=arch=compute_{cc[0]}{cc[1]},code=sm_{cc[0]}{cc[1]}"]
        except Exception:  # torch missing or no usable CUDA device
            # Default to common architectures
            gencode = [
                f"-gencode=arch=compute_{a},code=sm_{a}"
                for a in ("70", "75", "80", "86", "89", "90")
            ]
        # Argument list + shell=False: robust to spaces in paths, no shell injection.
        cmd = ["nvcc", "-O3", "--use_fast_math", "-shared",
               "-Xcompiler", "-fPIC", *gencode, "-o", out_file, cu_file]
        print(f"Compiling: {' '.join(cmd)}", flush=True)
        subprocess.run(cmd, check=True)
    return _find_lib()
class TritGEMV:
    """Framework-agnostic trit GEMV kernel.

    Thin ctypes wrapper around the pre-compiled libtrit_gemv shared library.
    Tensor arguments may come from PyTorch (``data_ptr``), from anything
    exposing ``__cuda_array_interface__`` (CuPy, JAX, Numba), or be raw
    integer device addresses.
    """

    def __init__(self, lib_path=None):
        """Load the library, declare C signatures, and query GPU properties.

        Args:
            lib_path: optional explicit path to the shared library. When
                None, searches next to this file and auto-builds if missing.

        Raises:
            RuntimeError: if the library cannot be found or built.
        """
        if lib_path is None:
            lib_path = _find_lib()
        if lib_path is None:
            lib_path = _build_lib()
        if lib_path is None:
            raise RuntimeError("Cannot find or build libtrit_gemv.so")
        self._lib = ctypes.CDLL(lib_path)

        # --- C function signatures -------------------------------------
        # d2 dp4a (champion)
        self._lib.trit_gemv_d2_dp4a.argtypes = [
            ctypes.c_void_p,  # pt (int32*) int4-packed weights
            ctypes.c_void_p,  # ws (float*) weight scales
            ctypes.c_void_p,  # xt (int32*) int8-packed activations
            ctypes.c_void_p,  # xs (float*) activation scales
            ctypes.c_void_p,  # y  (float*) output
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d2_dp4a.restype = None

        # d3 native trit
        self._lib.trit_gemv_d3_native.argtypes = [
            ctypes.c_void_p,  # pt
            ctypes.c_void_p,  # sc
            ctypes.c_void_p,  # x
            ctypes.c_void_p,  # y
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # depth
        ]
        self._lib.trit_gemv_d3_native.restype = None

        # d3 int8 dp4a (no decode, DRAM-bound path)
        self._lib.trit_gemv_d3_int8_dp4a.argtypes = [
            ctypes.c_void_p,  # wt (int32*)
            ctypes.c_void_p,  # ws (float*)
            ctypes.c_void_p,  # xt (int32*)
            ctypes.c_void_p,  # xs (float*)
            ctypes.c_void_p,  # y (float*)
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d3_int8_dp4a.restype = None

        # Utility entry points
        self._lib.get_l2_cache_bytes.restype = ctypes.c_int
        self._lib.cuda_sync.restype = None

        buf = ctypes.create_string_buffer(256)
        self._lib.get_gpu_name(buf, 256)
        self.gpu_name = buf.value.decode()              # device-0 name reported by the library
        self.l2_bytes = self._lib.get_l2_cache_bytes()  # L2 cache size in bytes

    def sync(self):
        """Block until all queued CUDA work has completed."""
        self._lib.cuda_sync()

    def _get_ptr(self, tensor):
        """Extract a GPU pointer (integer address) from any framework's tensor."""
        if hasattr(tensor, 'data_ptr'):
            # PyTorch
            return tensor.data_ptr()
        elif hasattr(tensor, '__cuda_array_interface__'):
            # CuPy, JAX, Numba
            return tensor.__cuda_array_interface__['data'][0]
        elif isinstance(tensor, int):
            # Raw pointer
            return tensor
        else:
            raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")

    def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D2 GEMV with int4 packing + dp4a.

        Args:
            pt: int32 tensor [rows * num_groups * 8] -- int4 packed weights
            ws: float32 tensor [rows * num_groups] -- weight scales
            xt: int32 tensor [num_groups * 16] -- int8 packed activations
            xs: float32 tensor [num_groups] -- activation scales
            y: float32 tensor [rows] -- output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: K // 64
            l2_persist: enable L2 cache persistence (default True)
        """
        self._lib.trit_gemv_d2_dp4a(
            self._get_ptr(pt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                      pt_int8=None):
        """Hardware-aware GEMV: auto-selects best kernel based on L2 cache.

        If the int4 weight data fits in L2 -> uses d2 int4 + dp4a (5x FP16).
        If not -> uses pre-expanded int8 + dp4a (2x FP16, no decode overhead).

        Args:
            pt_int4: int32 tensor -- int4 packed weights (always stored, compact)
            ws: weight scales
            xt, xs: quantized activations
            y: output
            pt_int8: optional pre-expanded int8 weights for the DRAM path.
                Required when the layer exceeds the L2 budget.

        Raises:
            ValueError: layer too large for L2 and no pt_int8 supplied.
        """
        weight_bytes = rows * num_groups * 8 * 4  # int4: 8 words per group
        l2_margin = self.l2_bytes * 0.75  # leave 25% for x, scales, other data
        if weight_bytes < l2_margin:
            # Fits in L2 -> use compact int4, decode inline at L2 speed
            self._lib.trit_gemv_d2_dp4a(
                self._get_ptr(pt_int4), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 1)
        else:
            # Doesn't fit L2 -> use int8 for zero-decode DRAM speed
            if pt_int8 is None:
                raise ValueError(
                    f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 ({self.l2_bytes/1e6:.0f} MB). "
                    f"Provide pre-expanded pt_int8 for DRAM path. "
                    f"Use TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
                )
            self._lib.trit_gemv_d3_int8_dp4a(
                self._get_ptr(pt_int8), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 0)

    @staticmethod
    def expand_int4_to_int8(pt_int4, device='cuda'):
        """Pre-expand int4 packed weights to int8 for DRAM-bound layers.

        Called once at model load. Uses 2x more VRAM but eliminates decode
        overhead. Layout: int4 = 8 words per group -> int8 = 16 words per
        group (dp4a compatible, 4 int8 bytes per little-endian int32 word).

        Fully vectorized on the target device; intermediate math runs in
        int64 so a sign-extended byte shifted into byte 3 (e.g. 0xFF << 24)
        cannot overflow the signed 32-bit range mid-computation.

        Args:
            pt_int4: int32 tensor [n_groups * 8] -- int4 packed (8 nibbles/word)
            device: device for the expanded tensor (default 'cuda')

        Returns:
            int32 tensor [n_groups * 16] -- int8 packed (dp4a compatible)
        """
        import torch
        n_words = pt_int4.shape[0]
        n_groups = n_words // 8
        words = pt_int4.to(device=device, dtype=torch.int64).reshape(n_groups, 8)
        # Split each word into its 8 nibbles; flattened order is w*8 + nib.
        nib_shift = torch.arange(8, device=device, dtype=torch.int64) * 4
        nibbles = (words.unsqueeze(-1) >> nib_shift) & 0xF          # [g, 8, 8]
        # Sign-extend 4-bit two's complement (0x8..0xF negative) to a byte.
        vals = torch.where(nibbles >= 8, nibbles - 16, nibbles) & 0xFF
        # Repack 4 consecutive int8 bytes into each int32 word
        # (out_word = col // 4, out_byte = col % 4).
        byte_shift = torch.arange(4, device=device, dtype=torch.int64) * 8
        packed = (vals.reshape(n_groups, 16, 4) << byte_shift).sum(dim=-1)
        # Fold the unsigned 32-bit pattern back into signed int32 range.
        packed = packed & 0xFFFFFFFF
        packed = torch.where(packed >= 2**31, packed - 2**32, packed)
        return packed.reshape(-1).to(torch.int32)

    def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
        """D3 GEMV with native trit packing.

        Args:
            pt: int32 tensor [rows * ng * 13] -- trit packed weights
            sc: float32 tensor [rows * ng] -- scales
            x: float32 tensor [cols] -- activations
            y: float32 tensor [rows] -- output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            depth: trit depth (default 3)
        """
        self._lib.trit_gemv_d3_native(
            self._get_ptr(pt), self._get_ptr(sc),
            self._get_ptr(x), self._get_ptr(y),
            cols, rows, depth,
        )

    def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D3 GEMV with int8 level packing + dp4a (same quality as d3, dp4a speed).

        Args:
            wt: int32 tensor [rows * num_groups * 16] -- int8 packed levels
            ws: float32 tensor [rows * num_groups] -- weight scales
            xt: int32 tensor [num_groups * 16] -- int8 packed activations
            xs: float32 tensor [num_groups * 16] -- per-word x scales
            y: float32 tensor [rows] -- output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: number of K-groups
            l2_persist: enable L2 cache persistence (default True)

        Raises:
            RuntimeError: if the loaded library lacks the d3 int8 symbol.
        """
        if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            raise RuntimeError("d3 int8 not in this build — rebuild libtrit_gemv.so")
        self._lib.trit_gemv_d3_int8_dp4a(
            self._get_ptr(wt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def __repr__(self):
        return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"