#!/bin/bash
#
# Build libtrit_gemv.so: compile trit_gemv_standalone.cu into a shared
# library with device code for Volta (sm_70) through Hopper (sm_90), plus
# Blackwell (sm_120) when the installed nvcc supports it.
#
# Requirements: CUDA toolkit — nvcc on PATH or at /usr/local/cuda/bin/nvcc.

set -euo pipefail
cd "$(dirname "$0")"

# Locate nvcc: prefer the one on PATH, fall back to the default CUDA prefix.
NVCC=$(command -v nvcc 2>/dev/null || echo "/usr/local/cuda/bin/nvcc")
if [ ! -x "$NVCC" ]; then
  echo "ERROR: nvcc not found. Install CUDA toolkit." >&2
  exit 1
fi

echo "Using nvcc: $NVCC"
"$NVCC" --version | head -1

# Target architectures, one -gencode pair per SM version. An array (expanded
# as "${ARCHS[@]}") keeps each flag a separate word without relying on
# unquoted word-splitting.
ARCHS=(
  -gencode=arch=compute_70,code=sm_70
  -gencode=arch=compute_75,code=sm_75
  -gencode=arch=compute_80,code=sm_80
  -gencode=arch=compute_86,code=sm_86
  -gencode=arch=compute_89,code=sm_89
  -gencode=arch=compute_90,code=sm_90
)

# Blackwell (sm_120) exists only in newer toolkits; include it when this
# nvcc advertises it. NOTE(review): grepping --help is a heuristic —
# confirm against `nvcc --list-gpu-arch` on recent CUDA versions.
if "$NVCC" --help 2>&1 | grep -q "compute_120"; then
  ARCHS+=(-gencode=arch=compute_120,code=sm_120)
  echo "Including Blackwell (sm_120)"
fi

echo "Building libtrit_gemv.so..."
"$NVCC" -O3 --use_fast_math \
  -shared -Xcompiler -fPIC \
  "${ARCHS[@]}" \
  -o libtrit_gemv.so \
  trit_gemv_standalone.cu

ls -la libtrit_gemv.so
echo "Done! Library ready at $(pwd)/libtrit_gemv.so"
echo ""
echo "Usage from Python:"
echo "  from trit_gemv_lib import TritGEMV"
echo "  lib = TritGEMV()"
echo "  lib.gemv_d2(weights, scales, x_int8, x_scales, output, K, M, ng)"