#!/usr/bin/env bash
#
# patch_triton_cuda13.sh - make Triton's NVIDIA backend accept CUDA 13.x.
#
# Triton's triton/backends/nvidia/compiler.py picks a PTX version from the
# CUDA toolkit version in ptx_get_version(cuda_version); releases that
# predate CUDA 13 do not recognise a 13.x toolkit. This script:
#   1. Patches compiler.py in place, adding a branch that maps
#      CUDA 13.<minor> to PTX version 90 + <minor>.
#   2. Installs a .pth-triggered runtime monkey-patch into site-packages that
#      applies the same mapping when the module is imported, so the fix
#      survives reinstalls of the triton package (e.g. by uv).
#
# Usage:
#   patch_triton_cuda13.sh [PATH_TO_COMPILER_PY]
#
# If PATH_TO_COMPILER_PY is omitted, it is located by importing
# triton.backends.nvidia.compiler with the interpreter named by $PYTHON
# (default: python).
#
# Requires GNU sed (in-place editing with `sed -i` and no backup suffix).

set -euo pipefail

# Locate Triton's NVIDIA compiler.py: an explicit argument wins, otherwise
# ask the interpreter where the installed module lives.
if (( $# >= 1 )); then
  COMPILER_PY="$1"
else
  # `|| { … }` turns an import failure (Triton missing, wrong interpreter)
  # into an actionable message instead of a bare traceback + set -e exit.
  COMPILER_PY="$("${PYTHON:-python}" -c \
    'import triton.backends.nvidia.compiler as c; print(c.__file__)')" || {
    echo "ERROR: Cannot import triton.backends.nvidia.compiler with '${PYTHON:-python}'; pass the path to compiler.py as the first argument" >&2
    exit 1
  }
fi

if [[ ! -f "$COMPILER_PY" ]]; then
  echo "ERROR: Cannot find Triton compiler.py at: $COMPILER_PY" >&2
  exit 1
fi

# Step 1: patch compiler.py in place. A CUDA 13 branch returning
# `90 + minor` is inserted immediately before each existing
# `if major == 12:` branch, mirroring the runtime monkey-patch of
# ptx_get_version() installed later in this script. The `major == 13`
# check keeps re-runs idempotent.
if grep -qF 'major == 13' "$COMPILER_PY"; then
  echo "Triton compiler.py already patched for CUDA 13.x"
else
  # Anchor on lines that *start* (after indentation) with the branch: an
  # `elif major == 12:` must not get an `if` spliced in front of it, which
  # would break its if/elif chain. Same pattern as the sed below, so this
  # check fails exactly when sed would have nothing to edit.
  if ! grep -qE '^[[:space:]]*if major == 12:' "$COMPILER_PY"; then
    echo "ERROR: Cannot find 'major == 12' in $COMPILER_PY — unexpected Triton version?" >&2
    exit 1
  fi

  # \1 captures the anchor line's indentation and is reused for the new
  # `if`, so the inserted code is valid Python at whatever nesting level the
  # anchor sits. The `return` body is indented four more spaces
  # (NOTE: assumes the file uses 4-space indentation). `&` re-emits the
  # original `if major == 12:` after the new branch. `\n` in the
  # replacement is a GNU sed extension, like `sed -i` without a suffix.
  sed -i -E 's/^([[:space:]]*)if major == 12:/\1if major == 13:\n\1    return 90 + minor\n&/' "$COMPILER_PY"

  # Confirm the edit actually landed before claiming success.
  if ! grep -qE '^[[:space:]]*if major == 13:' "$COMPILER_PY"; then
    echo "ERROR: Failed to patch $COMPILER_PY for CUDA 13.x" >&2
    exit 1
  fi
  echo "Patched $COMPILER_PY to support CUDA 13.x"
fi

# Step 2: install a runtime monkey-patch that re-applies the CUDA 13.x
# mapping at interpreter start-up, so it survives reinstalls of the triton
# package (a reinstall replaces the in-place edit made above).
SITE_PACKAGES="$("${PYTHON:-python}" -c "import site; print(site.getsitepackages()[0])")"
PTH_FILE="${SITE_PACKAGES}/triton_cuda13_patch.pth"

# Write the module BEFORE the .pth that imports it. If this step fails,
# no dangling .pth is left behind to make every interpreter start-up report
# an error for its failing `import triton_cuda13_patch` line.
cat > "${SITE_PACKAGES}/triton_cuda13_patch.py" << 'PYEOF'
"""Monkey-patch Triton to support CUDA 13.x (installed by patch_triton_cuda13.sh)."""
def _apply():
    try:
        from triton.backends.nvidia import compiler as _c
        _orig = _c.ptx_get_version
        def _patched(cuda_version):
            major, minor = map(int, cuda_version.split('.'))
            if major == 13:
                return 90 + minor
            return _orig(cuda_version)
        _c.ptx_get_version = _patched
    except Exception:
        # Best effort: this module is imported from a .pth file at every
        # interpreter start-up, so any failure (Triton absent, failing to
        # load, or a changed API) must not turn into start-up noise.
        pass
_apply()
del _apply
PYEOF

# A .pth line starting with "import" is executed by the site module at
# start-up, which is what triggers the patch above.
cat > "$PTH_FILE" << 'PTHEOF'
import triton_cuda13_patch
PTHEOF

echo "Installed ${PTH_FILE} (runtime monkey-patch, survives uv reinstalls)"