/*
 * TritLLM CUDA Kernel — Ternary GEMV (Matrix-Vector Multiply)
 *
 * Core operation: y = W_ternary @ x
 * Where W_ternary is packed ternary weights with per-group scales.
 *
 * Each group of 64 weights has:
 *   - A depth (1-4 trits per weight)
 *   - A FP16 scale factor
 *   - Packed trit values (2 bits per trit: 00=0, 01=+1, 10=-1, 11=unused)
 *
 * The key: a weight is never expanded to a floating-point value. Each weight
 * decodes to a small signed integer level (|level| <= 40) that multiplies the
 * activation; at depth 1 the level is -1, 0 or +1, i.e. a conditional
 * add/subtract. As written the accumulate is `level * x[idx]` (integer
 * converted to float, then multiplied); whether that is lowered to pure
 * add/subtract is up to the compiler.
 */

// NOTE(review): the header names were missing from this include block in the
// reviewed copy. The three below are inferred from what the file uses (CUDA
// intrinsics, the FP16 mention above, uint32_t/uint8_t) — confirm against the
// original.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <stdint.h>

#define GROUP_SIZE 64
#define WARP_SIZE 32

// Weights handled by each lane of a warp per group (64 / 32 = 2).
#define VALUES_PER_LANE (GROUP_SIZE / WARP_SIZE)

// Trit encoding: 2 bits per trit
// 00 = 0, 01 = +1, 10 = -1
#define TRIT_ZERO 0
#define TRIT_POS 1
#define TRIT_NEG 2

/*
 * Decode one 2-bit trit code into its signed value.
 *
 * Returns +1 for TRIT_POS, -1 for TRIT_NEG and 0 for anything else, so the
 * unused code 11 decodes as 0. Branch-free: difference of two comparisons.
 */
__host__ __device__ __forceinline__ int trit_decode_sign(uint32_t trit)
{
    return (trit == TRIT_POS) - (trit == TRIT_NEG);
}

/*
 * Depth 1 (3 levels: {-1, 0, +1}): 1 trit per weight, 2 bits per weight
 * Pack 16 trits per uint32 (16 * 2 = 32 bits)
 * Group of 64 = 4 uint32s
 *
 * Returns this lane's partial dot product over its two weights
 * (group elements lane*2 and lane*2 + 1). The caller is expected to sum the
 * partials of all 32 lanes. `lane` must be in [0, 31]; larger values index
 * past the group in both `packed` and `x`.
 */
__host__ __device__ __forceinline__ float trit_mac_d1(
    const uint32_t* __restrict__ packed,  // 4 uint32s = 64 trits
    const float* __restrict__ x,          // 64 activations
    int lane                              // warp lane (0-31)
) {
    float acc = 0.0f;

    // Each thread in warp handles 2 elements (64 / 32 = 2)
    #pragma unroll
    for (int i = 0; i < VALUES_PER_LANE; i++) {
        int idx = lane * VALUES_PER_LANE + i;
        int word = idx / 16;              // which uint32 (0-3)
        int bit_offset = (idx % 16) * 2;  // bit position within word

        uint32_t trit = (packed[word] >> bit_offset) & 0x3u;

        // Branch-free: adds x[idx], subtracts it, or contributes 0 * x[idx].
        acc += trit_decode_sign(trit) * x[idx];
    }

    return acc;
}

/*
 * Depth 2 (9 levels: {-4..+4}): 2 trits per weight, 4 bits per weight
 * level = s1 * 3 + s0, where s0 is the signed value of the low trit
 * (bits 0-1) and s1 that of the high trit (bits 2-3)
 * Pack 8 values per uint32 (8 * 4 = 32 bits)
 * Group of 64 = 8 uint32s
 *
 * Returns this lane's partial dot product over group elements lane*2 and
 * lane*2 + 1. `lane` must be in [0, 31].
 */
__host__ __device__ __forceinline__ float trit_mac_d2(
    const uint32_t* __restrict__ packed,  // 8 uint32s = 64 values
    const float* __restrict__ x,          // 64 activations
    int lane                              // warp lane (0-31)
) {
    float acc = 0.0f;

    #pragma unroll
    for (int i = 0; i < VALUES_PER_LANE; i++) {
        int idx = lane * VALUES_PER_LANE + i;
        int word = idx / 8;
        int bit_offset = (idx % 8) * 4;

        uint32_t bits = (packed[word] >> bit_offset) & 0xFu;

        int s0 = trit_decode_sign(bits & 0x3u);
        int s1 = trit_decode_sign((bits >> 2) & 0x3u);
        int level = s1 * 3 + s0;  // -4 to +4

        // The integer level is converted to float and multiplied here.
        acc += level * x[idx];
    }

    return acc;
}

/*
 * Depth 3 (27 levels: {-13..+13}): 3 trits per weight, 6 bits per weight
 * Pack 5 values per uint32 (5 * 6 = 30 bits, 2 wasted)
 * Group of 64 = 13 uint32s (64 values, last uint32 has 4 values)
 *
 * Returns this lane's partial dot product over group elements lane*2 and
 * lane*2 + 1. `lane` must be in [0, 31].
 */
__host__ __device__ __forceinline__ float trit_mac_d3(
    const uint32_t* __restrict__ packed,  // 13 uint32s
    const float* __restrict__ x,          // 64 activations
    int lane                              // warp lane (0-31)
) {
    float acc = 0.0f;

    #pragma unroll
    for (int i = 0; i < VALUES_PER_LANE; i++) {
        int idx = lane * VALUES_PER_LANE + i;
        int word = idx / 5;
        int pos = idx % 5;  // slot within the word (0-4)
        int bit_offset = pos * 6;

        uint32_t bits = (packed[word] >> bit_offset) & 0x3Fu;

        int s0 = trit_decode_sign(bits & 0x3u);
        int s1 = trit_decode_sign((bits >> 2) & 0x3u);
        int s2 = trit_decode_sign((bits >> 4) & 0x3u);
        int level = s2 * 9 + s1 * 3 + s0;  // -13 to +13

        acc += level * x[idx];
    }

    return acc;
}

/*
 * Depth 4 (81 levels: {-40..+40}): 4 trits per weight, 8 bits per weight
 * Pack 4 values per uint32 (4 * 8 = 32 bits, perfect)
 * Group of 64 = 16 uint32s
 *
 * Returns this lane's partial dot product over group elements lane*2 and
 * lane*2 + 1. `lane` must be in [0, 31].
 */
__host__ __device__ __forceinline__ float trit_mac_d4(
    const uint32_t* __restrict__ packed,  // 16 uint32s
    const float* __restrict__ x,          // 64 activations
    int lane                              // warp lane (0-31)
) {
    float acc = 0.0f;

    #pragma unroll
    for (int i = 0; i < VALUES_PER_LANE; i++) {
        int idx = lane * VALUES_PER_LANE + i;
        int word = idx / 4;
        int bit_offset = (idx % 4) * 8;

        uint32_t bits = (packed[word] >> bit_offset) & 0xFFu;

        int s0 = trit_decode_sign(bits & 0x3u);
        int s1 = trit_decode_sign((bits >> 2) & 0x3u);
        int s2 = trit_decode_sign((bits >> 4) & 0x3u);
        int s3 = trit_decode_sign((bits >> 6) & 0x3u);
        int level = s3 * 27 + s2 * 9 + s1 * 3 + s0;  // -40 to +40

        acc += level * x[idx];
    }

    return acc;
}

/*
 * Main GEMV kernel: y[out_features] = W[out_features, in_features] @ x[in_features]
 *
 * W is stored as packed ternary groups:
 *   - packed_trits: variable-length packed trit data per group
 *   - scales: one scale per group (passed to the kernels as float)
 *   - depths: uint8 depth per group (1-4)
 *   - group_offsets: offset of each group within packed_trits, counted in
 *     uint32 words (the variable-depth kernel uses it as a uint32_t index)
 *
 * One warp per output row, iterating over groups along the input dimension.
 * Warp reduction gives the final dot product.
 */

// Simplified version: uniform depth across all groups in a tensor
// (variable-depth version below)
// Launch contract: blockDim.x == 32 (one warp per block), in_features % 64 == 0.
// The kernel uses lane = threadIdx.x and a full-warp shuffle mask, so larger
// blocks would alias the lane index and race on y[row]. Trailing partial groups
// are an unsupported shape: when the contract is violated the kernel returns
// without writing y rather than computing a truncated result.
__global__ void trit_gemv_uniform(
    const uint32_t* __restrict__ packed_trits,  // packed trit data, ordered by (row, group)
    const float* __restrict__ scales,           // one scale per (row, group), indexed row * num_groups + g
    const float* __restrict__ x,                // [in_features]
    float* __restrict__ y,                      // [out_features]
    int in_features,
    int out_features,
    int depth                                   // uniform depth 1-4
) {
    // Launch-contract guards. Both conditions are uniform across the block, so
    // either every thread returns here or none does; y is left untouched.
    if (blockDim.x != WARP_SIZE) return;     // launch contract: 1 warp/block
    if (in_features % GROUP_SIZE) return;    // launch contract: K mod 64 == 0

    const int row = blockIdx.x;              // one block per output row
    if (row >= out_features) return;

    const int lane = threadIdx.x;            // lane within warp (0-31)
    const int num_groups = in_features / GROUP_SIZE;

    // Words per group depends on depth. An unrecognised depth keeps the
    // depth-1 stride and contributes 0 for every group (see the switch below).
    int words_per_group;
    switch (depth) {
        case 1:  words_per_group = 4;  break;  // 64 * 2 / 32
        case 2:  words_per_group = 8;  break;  // 64 * 4 / 32
        case 3:  words_per_group = 13; break;  // ceil(64 * 6 / 32)
        case 4:  words_per_group = 16; break;  // 64 * 8 / 32
        default: words_per_group = 4;  break;
    }

    // Offsets are formed in 64 bits: row * num_groups * words_per_group can
    // exceed INT_MAX for large weight matrices.
    const int64_t row_group_base = (int64_t)row * (int64_t)num_groups;

    float row_acc = 0.0f;  // only meaningful on lane 0

    for (int g = 0; g < num_groups; g++) {
        const int64_t group_index = row_group_base + g;
        const uint32_t* group_data = packed_trits + group_index * words_per_group;
        const float* group_x = x + (int64_t)g * GROUP_SIZE;
        const float scale = scales[group_index];

        // Per-lane partial dot product over this lane's two weights.
        float group_acc;
        switch (depth) {
            case 1:  group_acc = trit_mac_d1(group_data, group_x, lane); break;
            case 2:  group_acc = trit_mac_d2(group_data, group_x, lane); break;
            case 3:  group_acc = trit_mac_d3(group_data, group_x, lane); break;
            case 4:  group_acc = trit_mac_d4(group_data, group_x, lane); break;
            default: group_acc = 0.0f; break;
        }

        // Warp reduction: after the last step lane 0 holds the sum of all 32
        // partials (other lanes hold partial sums that are not used).
        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
            group_acc += __shfl_down_sync(0xFFFFFFFFu, group_acc, offset);
        }

        // Lane 0 accumulates the scaled result
        if (lane == 0) {
            row_acc += group_acc * scale;
        }
    }

    // Write output
    if (lane == 0) {
        y[row] = row_acc;
    }
}

/*
 * Variable-depth version: each group can have a different depth.
 * Uses a depth map and offset table to handle mixed-depth tensors.
 *
 * depth_map and group_offsets are indexed by group only, so every row uses
 * the same per-group depth layout. group_offsets[num_groups] is used as the
 * row stride (in uint32 words) of packed_trits. A depth outside 1-4
 * contributes 0 for that group.
 */
// Launch contract: blockDim.x == 32 (one warp per block), in_features % 64 == 0.
// When the contract is violated the kernel returns without writing y.
__global__ void trit_gemv_variable(
    const uint32_t* __restrict__ packed_trits,
    const float* __restrict__ scales,          // one scale per (row, group), indexed row * num_groups + g
    const uint8_t* __restrict__ depth_map,     // [num_groups_per_row] depth per group
    const int* __restrict__ group_offsets,     // [num_groups_per_row + 1] word offsets
    const float* __restrict__ x,               // [in_features]
    float* __restrict__ y,                     // [out_features]
    int in_features,
    int out_features
) {
    // Block-uniform guards: either the whole warp returns or none of it does.
    if (blockDim.x != WARP_SIZE) return;
    if (in_features % GROUP_SIZE) return;

    const int row = blockIdx.x;
    if (row >= out_features) return;

    const int lane = threadIdx.x;
    const int num_groups = in_features / GROUP_SIZE;

    // Loop-invariant row offset, formed in 64 bits because row * stride can
    // exceed INT_MAX for large weight matrices.
    const int64_t row_word_base = (int64_t)row * (int64_t)group_offsets[num_groups];
    const int64_t row_group_base = (int64_t)row * (int64_t)num_groups;

    float row_acc = 0.0f;  // only meaningful on lane 0

    for (int g = 0; g < num_groups; g++) {
        const int depth = depth_map[g];
        const int64_t word_offset = row_word_base + (int64_t)group_offsets[g];

        const uint32_t* group_data = packed_trits + word_offset;
        const float* group_x = x + (int64_t)g * GROUP_SIZE;
        const float scale = scales[row_group_base + g];

        // Per-lane partial dot product; depth is the same for every lane, so
        // the switch does not diverge within the warp.
        float group_acc;
        switch (depth) {
            case 1:  group_acc = trit_mac_d1(group_data, group_x, lane); break;
            case 2:  group_acc = trit_mac_d2(group_data, group_x, lane); break;
            case 3:  group_acc = trit_mac_d3(group_data, group_x, lane); break;
            case 4:  group_acc = trit_mac_d4(group_data, group_x, lane); break;
            default: group_acc = 0.0f; break;
        }

        // Warp reduction: lane 0 ends up with the sum of all 32 partials.
        #pragma unroll
        for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
            group_acc += __shfl_down_sync(0xFFFFFFFFu, group_acc, offset);
        }

        if (lane == 0) {
            row_acc += group_acc * scale;
        }
    }

    if (lane == 0) {
        y[row] = row_acc;
    }
}