#!/bin/bash
# Build libtrit_gemv.so — standalone CUDA kernel library
# No PyTorch, no Python, no framework dependency.
# Just nvcc + CUDA runtime.
#
# Fat binary: compiles for all major GPU architectures.
# The right kernel is selected at runtime based on the GPU.
set -e
cd "$(dirname "$0")"
# Detect nvcc
NVCC=$(command -v nvcc || echo "/usr/local/cuda/bin/nvcc")
if [ ! -x "$NVCC" ]; then
    echo "ERROR: nvcc not found. Install the CUDA toolkit." >&2
    exit 1
fi
echo "Using nvcc: $NVCC"
$NVCC --version | head -1
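# Surface the toolkit release up front: sm_89/sm_90 below require CUDA 11.8+,
# and sm_120 requires 12.8+. A minimal sketch, assuming nvcc prints
# "release <major>.<minor>" as stock toolkits do.
CUDA_VER=$($NVCC --version | grep -o 'release [0-9][0-9.]*' | awk '{print $2}')
echo "CUDA toolkit release: ${CUDA_VER:-unknown}"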
# Architecture targets (fat binary)
# Volta (V100), Turing (2080), Ampere (3090/A100),
# Ada (4080/4090), Hopper (H100), Blackwell (5070+)
ARCHS=""
ARCHS="$ARCHS -gencode=arch=compute_70,code=sm_70" # V100
ARCHS="$ARCHS -gencode=arch=compute_75,code=sm_75" # 2080
ARCHS="$ARCHS -gencode=arch=compute_80,code=sm_80" # A100, 3080
ARCHS="$ARCHS -gencode=arch=compute_86,code=sm_86" # 3090
ARCHS="$ARCHS -gencode=arch=compute_89,code=sm_89" # 4080, 4090
ARCHS="$ARCHS -gencode=arch=compute_90,code=sm_90" # H100
# Blackwell — only if nvcc supports it (CUDA 12.8+)
if $NVCC --list-gpu-arch 2>/dev/null | grep -q "compute_120"; then
    ARCHS="$ARCHS -gencode=arch=compute_120,code=sm_120"
    echo "Including Blackwell (sm_120)"
fi
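# Forward-compatibility sketch (an assumption that it is wanted here, not part
# of the original flag set): also embed PTX for the newest unconditional
# target, so GPUs newer than every sm_* above can still run the kernel via
# driver JIT compilation instead of failing to load.
ARCHS="$ARCHS -gencode=arch=compute_90,code=compute_90" # PTX fallback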
echo "Building libtrit_gemv.so..."
$NVCC -O3 --use_fast_math \
    -shared -Xcompiler -fPIC \
    $ARCHS \
    -o libtrit_gemv.so \
    trit_gemv_standalone.cu
ls -la libtrit_gemv.so
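# Optional sanity check: list the SASS architectures actually embedded in the
# fat binary. cuobjdump ships with standard CUDA toolkit installs; skip
# quietly if it is not on PATH.
if command -v cuobjdump >/dev/null 2>&1; then
    cuobjdump --list-elf libtrit_gemv.so | sed 's/^/  /'
fi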
echo "Done! Library ready at $(pwd)/libtrit_gemv.so"
echo ""
echo "Usage from Python:"
echo " from trit_gemv_lib import TritGEMV"
echo " lib = TritGEMV()"
echo " lib.gemv_d2(weights, scales, x_int8, x_scales, output, K, M, ng)"