|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from packaging.version import Version |
|
|
|
|
|
# Minimum CUDA toolkit version (major, minor) that the installed *driver* must
# support for CUDA graph conditional nodes to work; compared against the value
# reported by cuDriverGetVersion().
__CUDA_PYTHON_MINIMUM_VERSION_CUDA_GRAPH_CONDITIONAL_NODES_SUPPORTED__ = (12, 3)
|
|
|
|
|
|
|
|
def check_cuda_python_cuda_graphs_conditional_nodes_supported():
    """
    Verify that the environment supports CUDA graphs with conditional nodes.

    Checks, in order: that CUDA is available to torch, that `cuda-python` is
    installed and at least version 12.3.0, and that the installed CUDA driver
    supports a new-enough toolkit version.

    Raises:
        EnvironmentError: if CUDA is not available.
        ModuleNotFoundError: if `cuda-python` is not installed.
        ImportError: if `cuda-python` or the CUDA driver is too old, or if
            querying the driver version fails.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError("CUDA is not available")

    try:
        from cuda import cuda
    except ImportError as e:
        # Chain the original import failure so the root cause stays visible.
        raise ModuleNotFoundError("No `cuda-python` module. Please do `pip install cuda-python>=12.3`") from e

    from cuda import __version__ as cuda_python_version

    if Version(cuda_python_version) < Version("12.3.0"):
        raise ImportError(f"Found cuda-python {cuda_python_version}, but at least version 12.3.0 is needed.")

    error, driver_version = cuda.cuDriverGetVersion()
    if error != cuda.CUresult.CUDA_SUCCESS:
        raise ImportError(f"cuDriverGetVersion() returned {cuda.cuGetErrorString(error)}")

    # cuDriverGetVersion() encodes the version as 1000 * major + 10 * minor.
    driver_version_major = driver_version // 1000
    driver_version_minor = (driver_version % 1000) // 10

    driver_version = (driver_version_major, driver_version_minor)
    if driver_version < __CUDA_PYTHON_MINIMUM_VERSION_CUDA_GRAPH_CONDITIONAL_NODES_SUPPORTED__:
        required_version = __CUDA_PYTHON_MINIMUM_VERSION_CUDA_GRAPH_CONDITIONAL_NODES_SUPPORTED__
        # Bug fix: the required version used a comma separator ("12,3") instead
        # of the conventional dotted form ("12.3").
        raise ImportError(
            f"""Driver supports cuda toolkit version \
{driver_version_major}.{driver_version_minor}, but the driver needs to support \
at least {required_version[0]}.{required_version[1]}. Please update your cuda driver."""
        )
|
|
|
|
|
|
|
|
def skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported():
    """
    Skip the current pytest test when CUDA graphs with conditional nodes are
    unsupported (missing CUDA, missing/old `cuda-python`, or an old driver).
    """
    try:
        check_cuda_python_cuda_graphs_conditional_nodes_supported()
    except (ImportError, ModuleNotFoundError, EnvironmentError) as err:
        # Import pytest lazily so this module stays importable outside test runs.
        import pytest

        reason = (
            "Test using cuda graphs with conditional nodes is being skipped because "
            f"cuda graphs with conditional nodes aren't supported. Error message: {err}"
        )
        pytest.skip(reason)
|
|
|
|
|
|
|
|
def assert_drv(err):
    """
    Throws an exception if the return value of a cuda-python call is not success.
    """
    from cuda import cuda, cudart, nvrtc

    # (status type, success value, message template) — checked in order.
    known_statuses = (
        (cuda.CUresult, cuda.CUresult.CUDA_SUCCESS, "Cuda Error: {}"),
        (nvrtc.nvrtcResult, nvrtc.nvrtcResult.NVRTC_SUCCESS, "Nvrtc Error: {}"),
        (cudart.cudaError_t, cudart.cudaError_t.cudaSuccess, "Cuda Runtime Error: {}"),
    )
    for status_type, success_value, template in known_statuses:
        if isinstance(err, status_type):
            if err != success_value:
                raise RuntimeError(template.format(err))
            return
    raise RuntimeError("Unknown error type: {}".format(err))
|
|
|
|
|
|
|
|
def cu_call(f_call_out):
    """
    Makes calls to cuda-python's functions inside cuda.cuda more python by throwing an exception if they return a status which is not cudaSuccess
    """
    from cuda import cudart

    # cuda-python calls return (status, result0, result1, ...).
    status, *results = f_call_out
    if status == cudart.cudaError_t.cudaSuccess:
        return tuple(results)
    raise Exception(f"CUDA failure! {status}")
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
    """
    Insert a conditional WHILE node into a CUDA graph that is currently being
    captured on ``device``'s current torch stream, and capture the loop body
    (the code the caller runs inside the ``with`` block) into the node's body
    graph via a dedicated side stream.

    Even though we add a conditional node only once, we need to
    capture the kernel that calls cudaGraphSetConditional() both
    before in the parent graph containing the while loop body graph
    and after the rest of the while loop body graph (because we need
    to decide both whether to enter the loop, and also whether to
    execute the next iteration of the loop).

    Args:
        while_loop_kernel: CUfunction launched (single thread) to evaluate the
            loop condition; presumably it calls cudaGraphSetConditional() on
            the handle — the kernel source is not visible here.
        while_loop_args: kernel-argument buffer; its ``.ctypes.data`` pointer
            is handed to cuLaunchKernel (assumed to be a numpy array of
            argument pointers — confirm against callers).
        while_loop_conditional_handle: CUgraphConditionalHandle driving the node.
        device: torch CUDA device whose current stream is mid-capture.

    Yields:
        (body_stream, body_graph): the side stream now set as torch's current
        stream, and the CUgraph the loop body is being captured into.
    """
    from cuda import __version__ as cuda_python_version
    from cuda import cuda, cudart, nvrtc

    # The caller's stream must already be in active graph capture.
    capture_status, _, graph, _, _ = cu_call(
        cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
    )
    assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

    # First launch of the condition kernel: decides whether the loop is
    # entered at all. NOTE(review): the cuLaunchKernel status is not checked.
    cuda.cuLaunchKernel(
        while_loop_kernel,
        1,  # grid x
        1,  # grid y
        1,  # grid z
        1,  # block x
        1,  # block y
        1,  # block z
        0,  # shared memory bytes
        torch.cuda.current_stream(device=device).cuda_stream,
        while_loop_args.ctypes.data,
        0,
    )

    # Re-query capture info so `dependencies` reflects the node created by the
    # launch above; the conditional node must depend on it.
    capture_status, _, graph, dependencies, _ = cu_call(
        cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
    )
    assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive

    # Describe a WHILE-type conditional node with a single body graph.
    driver_params = cuda.CUgraphNodeParams()
    driver_params.type = cuda.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
    driver_params.conditional.handle = while_loop_conditional_handle
    driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE
    driver_params.conditional.size = 1
    if Version(cuda_python_version) == Version("12.3.0"):
        # Workaround specific to cuda-python 12.3.0: pre-allocate the
        # phGraph_out slot so the driver has somewhere to write the body
        # graph handle (presumably fixed in later releases — confirm against
        # the cuda-python changelog).
        driver_params.conditional.phGraph_out = [cuda.CUgraph()]
    (ctx,) = cu_call(cuda.cuCtxGetCurrent())
    driver_params.conditional.ctx = ctx

    # Add the conditional node to the graph being captured, depending on all
    # current capture dependencies; the driver fills in the body graph handle.
    (node,) = cu_call(cuda.cuGraphAddNode(graph, dependencies, len(dependencies), driver_params))
    body_graph = driver_params.conditional.phGraph_out[0]

    # Make subsequent work captured on the parent stream depend on the
    # conditional node (replacing, not appending to, the dependency set).
    cu_call(
        cudart.cudaStreamUpdateCaptureDependencies(
            torch.cuda.current_stream(device=device).cuda_stream,
            [node],
            1,
            cudart.cudaStreamUpdateCaptureDependenciesFlags.cudaStreamSetCaptureDependencies,
        )
    )
    # Capture the loop body into the conditional node's body graph on a
    # dedicated stream, and make it torch's current stream so the caller's
    # kernels inside the `with` block land there.
    body_stream = torch.cuda.Stream(device)
    previous_stream = torch.cuda.current_stream(device=device)
    cu_call(
        cudart.cudaStreamBeginCaptureToGraph(
            body_stream.cuda_stream,
            body_graph,
            None,
            None,
            0,
            cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal,
        )
    )
    torch.cuda.set_stream(body_stream)

    yield body_stream, body_graph

    # Second launch of the condition kernel, at the end of the body graph:
    # decides whether to run another iteration.
    cuda.cuLaunchKernel(
        while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, body_stream.cuda_stream, while_loop_args.ctypes.data, 0
    )

    # End body capture and restore the caller's stream.
    # NOTE(review): the cudaStreamEndCapture status is not checked.
    cudart.cudaStreamEndCapture(body_stream.cuda_stream)

    torch.cuda.set_stream(previous_stream)
|
|
|
|
|
|
|
|
def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes):
    """
    Compile a CUDA C++ kernel with NVRTC and load it into the current context.

    Args:
        kernel_string: CUDA C++ source code to compile.
        kernel_name: name (bytes) of the kernel to fetch from the compiled module.
        program_name: name (bytes) given to the NVRTC program, used in diagnostics.

    Returns:
        The CUfunction handle of the compiled kernel.

    Raises:
        RuntimeError: if any NVRTC or CUDA driver call fails; on compilation
            failure the NVRTC program log is included in the message.
    """
    from cuda import cuda, nvrtc

    err, prog = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), program_name, 0, [], [])
    assert_drv(err)

    opts = []
    (compile_result,) = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)

    # Bug fix: fetch the program log *before* checking the compile status.
    # Previously a failed compile raised immediately, discarding the compiler
    # diagnostics that live in the log.
    err, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
    assert_drv(err)
    log_buf = b" " * log_size
    (err,) = nvrtc.nvrtcGetProgramLog(prog, log_buf)
    assert_drv(err)

    if compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        raise RuntimeError(
            "Nvrtc Error: {}. Compilation log:\n{}".format(compile_result, log_buf.decode(errors="replace"))
        )

    err, ptx_size = nvrtc.nvrtcGetPTXSize(prog)
    assert_drv(err)
    ptx = b" " * ptx_size
    (err,) = nvrtc.nvrtcGetPTX(prog, ptx)
    assert_drv(err)

    # np.char.array gives a contiguous buffer whose .ctypes.data pointer
    # cuModuleLoadData can consume.
    ptx = np.char.array(ptx)
    err, module = cuda.cuModuleLoadData(ptx.ctypes.data)
    assert_drv(err)
    err, kernel = cuda.cuModuleGetFunction(module, kernel_name)
    assert_drv(err)

    return kernel
|
|
|