| | |
| | |
| |
|
| | import re |
| |
|
| | |
| | |
# PTX ISA version -> (major, minor) of the first CUDA release able to consume it.
# Values follow the "PTX ISA release history" table in NVIDIA's PTX ISA
# documentation; add a row here whenever a new PTX ISA version ships.
_ptx_to_cuda = {
    "1.0": (1, 0),
    "1.1": (1, 1),
    "1.2": (2, 0),
    "1.3": (2, 1),
    "1.4": (2, 2),
    "2.0": (3, 0),
    "2.1": (3, 1),
    "2.2": (3, 2),
    "2.3": (4, 0),
    "3.0": (4, 1),
    "3.1": (5, 0),
    "3.2": (5, 5),
    "4.0": (6, 0),
    "4.1": (6, 5),
    "4.2": (7, 0),
    "4.3": (7, 5),
    "5.0": (8, 0),
    "6.0": (9, 0),
    "6.1": (9, 1),
    "6.2": (9, 2),
    "6.3": (10, 0),
    "6.4": (10, 1),
    "6.5": (10, 2),
    "7.0": (11, 0),
    "7.1": (11, 1),
    "7.2": (11, 2),
    "7.3": (11, 3),
    "7.4": (11, 4),
    "7.5": (11, 5),
    "7.6": (11, 6),
    "7.7": (11, 7),
    "7.8": (11, 8),
    "8.0": (12, 0),
    "8.1": (12, 1),
    "8.2": (12, 2),
    "8.3": (12, 3),
    "8.4": (12, 4),
    "8.5": (12, 5),
    "8.6": (12, 7),
    "8.7": (12, 8),
    "8.8": (12, 9),
}


def get_minimal_required_cuda_ver_from_ptx_ver(ptx_version: str) -> int:
    """
    Maps the PTX ISA version to the minimal CUDA driver, nvPTXCompiler, or nvJitLink version
    that is needed to load a PTX of the given ISA version.

    Parameters
    ----------
    ptx_version : str
        PTX ISA version as a string, e.g. "8.8" for PTX ISA 8.8. This is the ``.version``
        directive in the PTX header. The string must match a key of ``_ptx_to_cuda``
        exactly (no surrounding whitespace).

    Returns
    -------
    int
        Minimal CUDA version as 1000 * major + 10 * minor, e.g. 12090 for CUDA 12.9.

    Raises
    ------
    ValueError
        If the PTX version is unknown.

    Examples
    --------
    >>> get_minimal_required_cuda_ver_from_ptx_ver("8.8")
    12090
    >>> get_minimal_required_cuda_ver_from_ptx_ver("7.0")
    11000
    """
    try:
        major, minor = _ptx_to_cuda[ptx_version]
    except KeyError:
        # Hide the internal KeyError; callers only need to know the version is unsupported.
        raise ValueError(f"Unknown or unsupported PTX ISA version: {ptx_version}") from None
    # Same encoding as CUDA's own version integers (e.g. cudaDriverGetVersion): 12.9 -> 12090.
    return 1000 * major + 10 * minor
| |
|
| |
|
| | |
| | |
# The ``.version`` directive must be the first non-comment statement of a PTX
# module, so it always starts a line (possibly after spaces/tabs). Anchoring the
# match with MULTILINE ``^`` prevents picking up ``.version`` text that appears
# inside a ``//`` comment earlier in the file.
_ptx_ver_pattern = re.compile(r"^[ \t]*\.version\s+([0-9]+\.[0-9]+)", re.MULTILINE)


def get_ptx_ver(ptx: str) -> str:
    """
    Extract the PTX ISA version string from PTX source code.

    Only a ``.version`` directive at the start of a line (after optional spaces or
    tabs) is recognized; occurrences inside ``//`` line comments are ignored.

    Parameters
    ----------
    ptx : str
        The PTX assembly source code as a string.

    Returns
    -------
    str
        The PTX ISA version string, e.g., "8.8".

    Raises
    ------
    ValueError
        If the .version directive is not found in the PTX source.

    Examples
    --------
    >>> ptx = r'''
    ... .version 8.8
    ... .target sm_86
    ... .address_size 64
    ...
    ... .visible .entry test_kernel()
    ... {
    ...     ret;
    ... }
    ... '''
    >>> get_ptx_ver(ptx)
    '8.8'
    """
    match = _ptx_ver_pattern.search(ptx)
    if match is None:
        raise ValueError("No .version directive found in PTX source. Is it a valid PTX?")
    return match.group(1)
| |
|