|
|
import ctypes |
|
|
import sys |
|
|
from typing import Any, Optional, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
from torch._utils import _get_device_index as _torch_get_device_index |
|
|
|
|
|
|
|
|
|
|
|
def _get_cuda_library() -> ctypes.CDLL:
    """Load and return the CUDA driver shared library for this platform.

    Raises:
        OSError: If the driver library cannot be loaded.
    """
    driver_name = "nvcuda.dll" if sys.platform == "win32" else "libcuda.so.1"
    return ctypes.CDLL(driver_name)
|
|
|
|
|
|
|
|
|
|
|
def _check_cuda(result: int) -> None: |
|
|
if result == 0: |
|
|
return |
|
|
err_str = ctypes.c_char_p() |
|
|
libcuda = _get_cuda_library() |
|
|
libcuda.cuGetErrorString(result, ctypes.byref(err_str)) |
|
|
error_message = ( |
|
|
err_str.value.decode() if err_str.value is not None else "Unknown CUDA error" |
|
|
) |
|
|
raise RuntimeError(f"CUDA error: {error_message}") |
|
|
|
|
|
|
|
|
def _get_nvrtc_library() -> ctypes.CDLL:
    """Load and return the NVRTC shared library matching the CUDA build.

    Tries the version-specific library name first (derived from the major
    version of ``torch.version.cuda``), then an unversioned fallback on
    non-Windows platforms.

    Raises:
        OSError: If none of the candidate NVRTC libraries can be loaded.
    """
    cuda_major = int(torch.version.cuda.split(".")[0])
    if sys.platform == "win32":
        candidates = [f"nvrtc64_{cuda_major}0_0.dll"]
    else:
        candidates = [f"libnvrtc.so.{cuda_major}", "libnvrtc.so"]

    for candidate in candidates:
        try:
            return ctypes.CDLL(candidate)
        except OSError:
            # Try the next candidate name.
            continue

    raise OSError("Could not find any NVRTC library")
|
|
|
|
|
|
|
|
def _nvrtc_compile(
    kernel_source: str,
    kernel_name: str,
    compute_capability: Optional[str] = None,
    header_code: str = "",
    cuda_include_dirs: Optional[list] = None,
    nvcc_options: Optional[list] = None,
) -> bytes:
    """
    Compiles a CUDA kernel using NVRTC and returns the PTX code.

    Args:
        kernel_source (str): The CUDA kernel source code as a string
        kernel_name (str): The name of the kernel function to compile
        compute_capability (str, None): The compute capability to target (e.g., "86").
            If None, will detect from current device.
        header_code (str, optional): Additional header code to prepend to the kernel source
        cuda_include_dirs (list, None): List of directories containing CUDA headers
        nvcc_options (list, None): Additional options to pass to NVRTC

    Returns:
        bytes: The compiled PTX code

    Raises:
        RuntimeError: If NVRTC reports an error, or if compilation fails
            (the NVRTC program log is included in the message).
        OSError: If no NVRTC shared library can be found.
    """
    import torch.cuda

    libnvrtc = _get_nvrtc_library()

    # NVRTC's success status code (NVRTC_SUCCESS).
    NVRTC_SUCCESS = 0

    def check_nvrtc(result: int) -> None:
        # Translate a nonzero NVRTC status into a RuntimeError carrying the
        # library's own error string.
        if result != NVRTC_SUCCESS:
            err_str = ctypes.c_char_p()
            libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
            error_message = (
                err_str.value.decode()
                if err_str.value is not None
                else "Unknown CUDA error"
            )
            raise RuntimeError(f"CUDA error: {error_message}")

    # Wrap the source in extern "C" (unless the caller already did) so the
    # compiled symbol keeps the plain kernel name and can be looked up later.
    if not kernel_source.strip().startswith('extern "C"'):
        kernel_source = f'extern "C" {kernel_source}'

    if header_code:
        full_source = header_code + "\n" + kernel_source
    else:
        full_source = kernel_source

    source_bytes = full_source.encode("utf-8")

    # Default to the compute capability of the current device.
    if compute_capability is None:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        compute_capability = f"{props.major}{props.minor}"

    # NVRTC expects options as a C array of byte strings.
    options = [f"--gpu-architecture=sm_{compute_capability}".encode()]

    if cuda_include_dirs:
        for directory in cuda_include_dirs:
            options.append(f"-I{directory}".encode())

    if nvcc_options:
        for option in nvcc_options:
            options.append(option.encode("utf-8"))

    from torch.utils.cpp_extension import COMMON_NVCC_FLAGS

    # NVRTC does not accept --expt-relaxed-constexpr; drop it from the
    # flags shared with the nvcc build path.
    nvrtc_compatible_flags = [
        flag for flag in COMMON_NVCC_FLAGS if flag != "--expt-relaxed-constexpr"
    ]
    options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])

    num_options = len(options)
    options_array = (ctypes.c_char_p * num_options)(*options)

    prog = ctypes.c_void_p()
    check_nvrtc(
        libnvrtc.nvrtcCreateProgram(
            ctypes.byref(prog),
            source_bytes,
            f"{kernel_name}.cu".encode(),
            0,
            None,
            None,
        )
    )

    # Destroy the NVRTC program on every path: previously a failed compile
    # leaked the program handle because nvrtcDestroyProgram only ran on success.
    try:
        res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)

        if res != NVRTC_SUCCESS:
            # Surface the compiler log so the user sees the actual diagnostics.
            log_size = ctypes.c_size_t()
            libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
            log = ctypes.create_string_buffer(log_size.value)
            libnvrtc.nvrtcGetProgramLog(prog, log)
            raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")

        ptx_size = ctypes.c_size_t()
        check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size)))
        ptx = ctypes.create_string_buffer(ptx_size.value)
        check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx))
        # ptx.value copies the buffer contents, so destroying the program in
        # the finally block afterwards is safe.
        return ptx.value
    finally:
        libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))
|
|
|
|
|
|
|
|
class _CudaModule:
    """Wraps a loaded CUDA module handle and exposes its kernels as attributes."""

    def __init__(self, module: ctypes.c_void_p) -> None:
        self._module = module
        # Kernels already resolved from this module, keyed by kernel name.
        self._kernels: dict[str, _CudaKernel] = {}

    def __getattr__(self, name: str) -> "_CudaKernel":
        # Only invoked when regular attribute lookup fails, i.e. for kernel
        # names. Serve from the cache when possible.
        if name in self._kernels:
            return self._kernels[name]

        from torch.cuda._utils import _get_cuda_library

        driver = _get_cuda_library()

        handle = ctypes.c_void_p()
        try:
            _check_cuda(
                driver.cuModuleGetFunction(
                    ctypes.byref(handle), self._module, name.encode("utf-8")
                )
            )
            resolved = _CudaKernel(handle, self._module)
            self._kernels[name] = resolved
            return resolved
        except RuntimeError as err:
            # Map the driver failure onto the attribute protocol.
            raise AttributeError(f"No kernel named '{name}' in this module") from err
|
|
|
|
|
|
|
|
class _CudaKernel:
    """
    Represents a compiled CUDA kernel that can be called with PyTorch tensors.
    """

    def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
        # func: the driver-level function handle for this kernel.
        # module: the handle of the module the kernel was loaded from; kept
        # so the owning module stays referenced alongside the function.
        self.func = func
        self.module = module

    def __call__(
        self,
        grid: tuple[int, int, int] = (1, 1, 1),
        block: tuple[int, int, int] = (1, 1, 1),
        args: Optional[list] = None,
        shared_mem: int = 0,
        stream: Optional[Any] = None,
    ) -> None:
        """
        Call the compiled CUDA kernel

        Args:
            grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
            block (tuple): Block dimensions (block_x, block_y, block_z)
            args (list): List of arguments to pass to the kernel.
                PyTorch tensor arguments will be automatically converted to pointers.
            shared_mem (int): Shared memory size in bytes
            stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream.

        Raises:
            ValueError: If a tensor argument is neither a CUDA tensor nor a
                pinned CPU tensor.
            TypeError: If an argument is not a tensor, int, or float.
        """
        import torch

        libcuda = torch.cuda._utils._get_cuda_library()

        if not args:
            args = []

        # processed_args holds the c_void_p wrappers for tensor pointers;
        # keeping them (and the byref objects in c_args) referenced until the
        # launch call returns prevents ctypes from freeing them early.
        processed_args: list[ctypes.c_void_p] = []
        c_args = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()):
                    raise ValueError(
                        "All tensor arguments must be CUDA tensors or pinned CPU tensors"
                    )

                # Pass the tensor by its raw data pointer.
                ptr = ctypes.c_void_p(arg.data_ptr())
                processed_args.append(ptr)
                c_args.append(ctypes.byref(ptr))
            elif isinstance(arg, int):
                # Ints are marshalled as 32-bit C ints.
                # NOTE(review): Python ints outside the 32-bit range will be
                # truncated by c_int — confirm kernels only take 32-bit ints.
                c_int = ctypes.c_int(arg)
                # byref keeps c_int alive via the reference stored in c_args.
                c_args.append(ctypes.byref(c_int))

            elif isinstance(arg, float):
                # Floats are marshalled as single-precision C floats.
                # NOTE(review): double-precision kernel parameters would need
                # c_double here — confirm the expected kernel signatures.
                c_float = ctypes.c_float(arg)

                c_args.append(ctypes.byref(c_float))
            else:
                raise TypeError(f"Unsupported argument type: {type(arg)}")

        # Build the void** kernel-parameter array the driver launch expects:
        # each slot points at one argument's storage.
        c_args_array = (ctypes.c_void_p * len(c_args))()
        for i, arg in enumerate(c_args):
            c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p)

        if stream is None:
            # Default to the current PyTorch CUDA stream.
            import torch.cuda

            stream = torch.cuda.current_stream()

        _check_cuda(
            libcuda.cuLaunchKernel(
                self.func,
                grid[0],
                grid[1],
                grid[2],
                block[0],
                block[1],
                block[2],
                shared_mem,
                # Stream objects expose their driver handle via _as_parameter_.
                stream._as_parameter_,
                c_args_array,
                None,
            )
        )
|
|
|
|
|
|
|
|
def _cuda_load_module(
    ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None
) -> Union[_CudaModule, dict[str, "_CudaKernel"]]:
    """
    Loads a CUDA module from PTX code and returns a module object that can access kernels.

    Args:
        ptx (bytes or str): The PTX code to load
        kernel_names (list, optional): List of kernel names to extract from the module.
            If None, will return a module object with __getattr__.

    Returns:
        object: If kernel_names is None, returns a module object with __getattr__ to access kernels.
                If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects.
    """
    import torch.cuda

    driver = _get_cuda_library()

    # The driver API wants raw bytes.
    data = ptx.encode("utf-8") if isinstance(ptx, str) else ptx

    module = ctypes.c_void_p()
    # Load the module with the current PyTorch stream made current.
    with torch.cuda.current_stream():
        _check_cuda(driver.cuModuleLoadData(ctypes.byref(module), data))

    if not kernel_names:
        # Lazy access: kernels are resolved on attribute lookup.
        return _CudaModule(module)

    # Eager access: resolve every requested kernel up front.
    resolved: dict[str, _CudaKernel] = {}
    for kernel_name in kernel_names:
        handle = ctypes.c_void_p()
        _check_cuda(
            driver.cuModuleGetFunction(
                ctypes.byref(handle), module, kernel_name.encode("utf-8")
            )
        )
        resolved[kernel_name] = _CudaKernel(handle, module)
    return resolved
|
|
|
|
|
|
|
|
def _get_device_index( |
|
|
device: Any, optional: bool = False, allow_cpu: bool = False |
|
|
) -> int: |
|
|
r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. |
|
|
|
|
|
If :attr:`device` is a torch.device object, returns the device index if it |
|
|
is a CUDA device. Note that for a CUDA device without a specified index, |
|
|
i.e., ``torch.device('cuda')``, this will return the current default CUDA |
|
|
device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, |
|
|
CPU devices will be accepted and ``-1`` will be returned in this case. |
|
|
|
|
|
If :attr:`device` is a Python integer, it is returned as is. |
|
|
|
|
|
If :attr:`device` is ``None``, this will return the current default CUDA |
|
|
device if :attr:`optional` is ``True``. |
|
|
""" |
|
|
if isinstance(device, int): |
|
|
return device |
|
|
if isinstance(device, str): |
|
|
device = torch.device(device) |
|
|
if isinstance(device, torch.device): |
|
|
if allow_cpu: |
|
|
if device.type not in ["cuda", "cpu"]: |
|
|
raise ValueError(f"Expected a cuda or cpu device, but got: {device}") |
|
|
elif device.type != "cuda": |
|
|
raise ValueError(f"Expected a cuda device, but got: {device}") |
|
|
if not torch.jit.is_scripting(): |
|
|
if isinstance(device, torch.cuda.device): |
|
|
return device.idx |
|
|
return _torch_get_device_index(device, optional, allow_cpu) |
|
|
|