File size: 3,165 Bytes
76f9669 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
import re
# Mapping based on the official PTX ISA <-> CUDA Release table
# https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes-ptx-release-history
_ptx_to_cuda = {
"1.0": (1, 0),
"1.1": (1, 1),
"1.2": (2, 0),
"1.3": (2, 1),
"1.4": (2, 2),
"2.0": (3, 0),
"2.1": (3, 1),
"2.2": (3, 2),
"2.3": (4, 0),
"3.0": (4, 1),
"3.1": (5, 0),
"3.2": (5, 5),
"4.0": (6, 0),
"4.1": (6, 5),
"4.2": (7, 0),
"4.3": (7, 5),
"5.0": (8, 0),
"6.0": (9, 0),
"6.1": (9, 1),
"6.2": (9, 2),
"6.3": (10, 0),
"6.4": (10, 1),
"6.5": (10, 2),
"7.0": (11, 0),
"7.1": (11, 1),
"7.2": (11, 2),
"7.3": (11, 3),
"7.4": (11, 4),
"7.5": (11, 5),
"7.6": (11, 6),
"7.7": (11, 7),
"7.8": (11, 8),
"8.0": (12, 0),
"8.1": (12, 1),
"8.2": (12, 2),
"8.3": (12, 3),
"8.4": (12, 4),
"8.5": (12, 5),
"8.6": (12, 7),
"8.7": (12, 8),
"8.8": (12, 9),
}
def get_minimal_required_cuda_ver_from_ptx_ver(ptx_version: str) -> int:
"""
Maps the PTX ISA version to the minimal CUDA driver, nvPTXCompiler, or nvJitLink version
that is needed to load a PTX of the given ISA version.
Parameters
----------
ptx_version : str
PTX ISA version as a string, e.g. "8.8" for PTX ISA 8.8. This is the ``.version``
directive in the PTX header.
Returns
-------
int
Minimal CUDA version as 1000 * major + 10 * minor, e.g. 12090 for CUDA 12.9.
Raises
------
ValueError
If the PTX version is unknown.
Examples
--------
>>> get_minimal_required_driver_ver_from_ptx_ver("8.8")
12090
>>> get_minimal_required_driver_ver_from_ptx_ver("7.0")
11000
"""
try:
major, minor = _ptx_to_cuda[ptx_version]
return 1000 * major + 10 * minor
except KeyError:
raise ValueError(f"Unknown or unsupported PTX ISA version: {ptx_version}") from None
# Regex pattern to match .version directive and capture the version number
# TODO: if import speed is a concern, consider lazy-initializing it.
_ptx_ver_pattern = re.compile(r"\.version\s+([0-9]+\.[0-9]+)")
def get_ptx_ver(ptx: str) -> str:
"""
Extract the PTX ISA version string from PTX source code.
Parameters
----------
ptx : str
The PTX assembly source code as a string.
Returns
-------
str
The PTX ISA version string, e.g., "8.8".
Raises
------
ValueError
If the .version directive is not found in the PTX source.
Examples
--------
>>> ptx = r'''
... .version 8.8
... .target sm_86
... .address_size 64
...
... .visible .entry test_kernel()
... {
... ret;
... }
... '''
>>> get_ptx_ver(ptx)
'8.8'
"""
m = _ptx_ver_pattern.search(ptx)
if m:
return m.group(1)
else:
raise ValueError("No .version directive found in PTX source. Is it a valid PTX?")
|