"""Framework-agnostic trit GEMV library.

Loads the pre-compiled libtrit_gemv.so via ctypes.
Works with PyTorch, JAX, CuPy, or raw CUDA pointers.

Compile the library once:
    cd kernel/
    ./build.sh

Then use from any framework:
    from trit_gemv_lib import TritGEMV
    lib = TritGEMV()

    # PyTorch
    lib.gemv_d2(pt_tensor, ws_tensor, xt_tensor, xs_tensor, y_tensor, cols, rows, ng)

    # Raw pointers (CuPy, JAX, etc.); plain integer device pointers are accepted
    lib.gemv_d2(pt_ptr, ws_ptr, xt_ptr, xs_ptr, y_ptr, cols, rows, ng)
"""
import ctypes
import os
import subprocess

# Find the library
_LIB_NAMES = ['libtrit_gemv.so', 'libtrit_gemv.dll', 'trit_gemv.so']
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def _find_lib():
    for name in _LIB_NAMES:
        path = os.path.join(_SCRIPT_DIR, name)
        if os.path.exists(path):
            return path
    return None


def _build_lib():
    """Auto-compile if not found."""
    build_script = os.path.join(_SCRIPT_DIR, 'build.sh')
    if os.path.exists(build_script):
        print("Building libtrit_gemv.so...", flush=True)
        subprocess.run(['bash', build_script], cwd=_SCRIPT_DIR, check=True)
    else:
        # Inline build
        cu_file = os.path.join(_SCRIPT_DIR, 'trit_gemv_standalone.cu')
        out_file = os.path.join(_SCRIPT_DIR, 'libtrit_gemv.so')
        if not os.path.exists(cu_file):
            raise FileNotFoundError(f"Cannot find {cu_file}")

        # Detect GPU architecture
        try:
            import torch
            cc = torch.cuda.get_device_capability(0)
            arch = f"compute_{cc[0]}{cc[1]}"
            sm = f"sm_{cc[0]}{cc[1]}"
            gencode = f"-gencode=arch={arch},code={sm}"
        except Exception:
            # torch unavailable or no CUDA device; fall back to a fat binary
            # covering common architectures
            gencode = " ".join([
                f"-gencode=arch=compute_{a},code=sm_{a}"
                for a in ["70", "75", "80", "86", "89", "90"]
            ])

        cmd = f"nvcc -O3 --use_fast_math -shared -Xcompiler -fPIC {gencode} -o {out_file} {cu_file}"
        print(f"Compiling: {cmd}", flush=True)
        subprocess.run(cmd, shell=True, check=True)

    return _find_lib()


class TritGEMV:
    """Framework-agnostic trit GEMV kernel."""

    def __init__(self, lib_path=None):
        if lib_path is None:
            lib_path = _find_lib()
        if lib_path is None:
            lib_path = _build_lib()
        if lib_path is None:
            raise RuntimeError("Cannot find or build libtrit_gemv.so")

        self._lib = ctypes.CDLL(lib_path)

        # Set up function signatures
        # d2 dp4a (champion)
        self._lib.trit_gemv_d2_dp4a.argtypes = [
            ctypes.c_void_p,  # pt (int32*)
            ctypes.c_void_p,  # ws (float*)
            ctypes.c_void_p,  # xt (int32*)
            ctypes.c_void_p,  # xs (float*)
            ctypes.c_void_p,  # y (float*)
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # num_groups
            ctypes.c_int,     # use_l2_persist
        ]
        self._lib.trit_gemv_d2_dp4a.restype = None

        # d3 native trit
        self._lib.trit_gemv_d3_native.argtypes = [
            ctypes.c_void_p,  # pt
            ctypes.c_void_p,  # sc
            ctypes.c_void_p,  # x
            ctypes.c_void_p,  # y
            ctypes.c_int,     # cols
            ctypes.c_int,     # rows
            ctypes.c_int,     # depth
        ]
        self._lib.trit_gemv_d3_native.restype = None

        # d3 int8 dp4a (no decode, DRAM-bound path); may be absent in older builds
        if hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            self._lib.trit_gemv_d3_int8_dp4a.argtypes = [
                ctypes.c_void_p,  # wt (int32*)
                ctypes.c_void_p,  # ws (float*)
                ctypes.c_void_p,  # xt (int32*)
                ctypes.c_void_p,  # xs (float*)
                ctypes.c_void_p,  # y (float*)
                ctypes.c_int,     # cols
                ctypes.c_int,     # rows
                ctypes.c_int,     # num_groups
                ctypes.c_int,     # use_l2_persist
            ]
            self._lib.trit_gemv_d3_int8_dp4a.restype = None

        # Utility
        self._lib.get_l2_cache_bytes.restype = ctypes.c_int
        self._lib.cuda_sync.restype = None

        buf = ctypes.create_string_buffer(256)
        self._lib.get_gpu_name(buf, 256)
        self.gpu_name = buf.value.decode()
        self.l2_bytes = self._lib.get_l2_cache_bytes()

    def sync(self):
        self._lib.cuda_sync()

    def _get_ptr(self, tensor):
        """Extract GPU pointer from any framework's tensor."""
        if hasattr(tensor, 'data_ptr'):
            # PyTorch
            return tensor.data_ptr()
        elif hasattr(tensor, '__cuda_array_interface__'):
            # CuPy, JAX, Numba
            return tensor.__cuda_array_interface__['data'][0]
        elif isinstance(tensor, int):
            # Raw pointer
            return tensor
        else:
            raise TypeError(f"Cannot extract GPU pointer from {type(tensor)}")

    def gemv_d2(self, pt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D2 GEMV with int4 packing + dp4a.

        Args:
            pt: int32 tensor [rows * num_groups * 8] β€” int4 packed weights
            ws: float32 tensor [rows * num_groups] β€” weight scales
            xt: int32 tensor [num_groups * 16] β€” int8 packed activations
            xs: float32 tensor [num_groups] β€” activation scales
            y:  float32 tensor [rows] β€” output (written in-place)
            cols: input dimension (K)
            rows: output dimension (M)
            num_groups: K // 64
            l2_persist: enable L2 cache persistence (default True)
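
        Example (illustrative shapes only; assumes PyTorch and the group
        size of 64 implied by num_groups = K // 64):
            K, M = 4096, 4096
            ng = K // 64
            pt = torch.zeros(M * ng * 8, dtype=torch.int32, device='cuda')
            ws = torch.zeros(M * ng, dtype=torch.float32, device='cuda')
            xt = torch.zeros(ng * 16, dtype=torch.int32, device='cuda')
            xs = torch.zeros(ng, dtype=torch.float32, device='cuda')
            y = torch.zeros(M, dtype=torch.float32, device='cuda')
            TritGEMV().gemv_d2(pt, ws, xt, xs, y, K, M, ng)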
        """
        self._lib.trit_gemv_d2_dp4a(
            self._get_ptr(pt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def gemv_adaptive(self, pt_int4, ws, xt, xs, y, cols, rows, num_groups,
                      pt_int8=None):
        """Hardware-aware GEMV: auto-selects best kernel based on L2 cache.

        If the int4 weight data fits in L2 β†’ uses d2 int4 + dp4a (5x FP16)
        If not β†’ uses pre-expanded int8 + dp4a (2x FP16, no decode overhead)

        Args:
            pt_int4: int32 tensor β€” int4 packed weights (always stored, compact)
            ws: weight scales
            xt, xs: quantized activations
            y: output
            pt_int8: optional pre-expanded int8 weights for the DRAM path.
                     Required when the layer does not fit in L2; build it once
                     at load time with TritGEMV.expand_int4_to_int8(pt_int4).
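
        Example (illustrative; pre-expand DRAM-bound layers once at load):
            pt8 = TritGEMV.expand_int4_to_int8(pt_int4)
            lib.gemv_adaptive(pt_int4, ws, xt, xs, y, K, M, ng, pt_int8=pt8)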
        """
        weight_bytes = rows * num_groups * 8 * 4  # int4 path: 8 int32 words per group, 4 bytes each
        l2_margin = self.l2_bytes * 0.75  # leave 25% of L2 for x, scales, other data

        if weight_bytes < l2_margin:
            # Fits in L2 β†’ use compact int4, decode inline at L2 speed
            self._lib.trit_gemv_d2_dp4a(
                self._get_ptr(pt_int4), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 1)
        else:
            # Doesn't fit L2 β†’ use int8 for zero-decode DRAM speed
            if pt_int8 is None:
                raise ValueError(
                    f"Layer ({weight_bytes/1e6:.0f} MB) exceeds L2 ({self.l2_bytes/1e6:.0f} MB). "
                    f"Provide pre-expanded pt_int8 for DRAM path. "
                    f"Use TritGEMV.expand_int4_to_int8(pt_int4) at model load time."
                )
            self._lib.trit_gemv_d3_int8_dp4a(
                self._get_ptr(pt_int8), self._get_ptr(ws),
                self._get_ptr(xt), self._get_ptr(xs),
                self._get_ptr(y), cols, rows, num_groups, 0)

    @staticmethod
    def expand_int4_to_int8(pt_int4, device='cuda'):
        """Pre-expand int4 packed weights to int8 for DRAM-bound layers.

        Called once at model load. Uses 2x more VRAM but eliminates decode overhead.
        int4: 8 words per group β†’ int8: 16 words per group

        Args:
            pt_int4: int32 tensor [n_groups * 8] β€” int4 packed
        Returns:
            int32 tensor [n_groups * 16] β€” int8 packed (dp4a compatible)
        """
        import torch
        n_words = pt_int4.shape[0]
        n_groups = n_words // 8

        # Each int4 word holds 8 nibbles → 8 int8 values → 2 int8x4 words
        out = [0] * (n_groups * 16)

        # Expand on the host (single device-to-host copy; one-time load cost)
        pt_cpu = pt_int4.tolist()
        for g in range(n_groups):
            for w in range(8):
                word = pt_cpu[g * 8 + w]
                for nib in range(8):
                    val = (word >> (nib * 4)) & 0xF
                    if val & 0x8:
                        val -= 16  # sign-extend the nibble
                    val &= 0xFF    # keep the low byte (two's complement)
                    out_col = w * 8 + nib
                    out[g * 16 + out_col // 4] |= val << ((out_col % 4) * 8)

        # reinterpret as signed int32 so bytes with the high bit set fit int32
        out = [v - (1 << 32) if v & 0x80000000 else v for v in out]
        return torch.tensor(out, dtype=torch.int32, device=device)
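
    # A vectorized sketch of the same expansion (not part of the API; assumes
    # PyTorch tensor bit ops and the 1-D int32 layout documented above):
    #
    #     shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=pt_int4.device)
    #     nibs = (pt_int4.unsqueeze(1) >> shifts) & 0xF          # [n_words, 8]
    #     nibs = torch.where(nibs >= 8, nibs - 16, nibs) & 0xFF  # sign-extended bytes
    #     b = nibs.reshape(-1, 4)                                # 4 bytes per int32 word
    #     pt_int8 = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24)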

    def gemv_d3(self, pt, sc, x, y, cols, rows, depth=3):
        """D3 GEMV with native trit packing.

        Args:
            pt: int32 tensor [rows * ng * 13] β€” trit packed weights
            sc: float32 tensor [rows * ng] β€” scales
            x:  float32 tensor [cols] β€” activations
            y:  float32 tensor [rows] β€” output
        """
        self._lib.trit_gemv_d3_native(
            self._get_ptr(pt), self._get_ptr(sc),
            self._get_ptr(x), self._get_ptr(y),
            cols, rows, depth,
        )

    def gemv_d3_int8(self, wt, ws, xt, xs, y, cols, rows, num_groups, l2_persist=True):
        """D3 GEMV with int8 level packing + dp4a (same quality as d3, dp4a speed).

        Args:
            wt: int32 tensor [rows * num_groups * 16] β€” int8 packed levels
            ws: float32 tensor [rows * num_groups] β€” weight scales
            xt: int32 tensor [num_groups * 16] β€” int8 packed activations
            xs: float32 tensor [num_groups * 16] β€” per-word x scales
            y:  float32 tensor [rows] β€” output
        """
        if not hasattr(self._lib, 'trit_gemv_d3_int8_dp4a'):
            raise RuntimeError("d3 int8 not in this build β€” rebuild libtrit_gemv.so")
        self._lib.trit_gemv_d3_int8_dp4a(
            self._get_ptr(wt), self._get_ptr(ws),
            self._get_ptr(xt), self._get_ptr(xs),
            self._get_ptr(y), cols, rows, num_groups,
            1 if l2_persist else 0,
        )

    def __repr__(self):
        return f"TritGEMV(gpu='{self.gpu_name}', l2={self.l2_bytes/1e6:.0f}MB)"