| | |
| | |
| |
|
| | import ctypes |
| | import ctypes.wintypes |
| | import os |
| | import struct |
| |
|
| | from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL |
| | from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import ( |
| | LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY, |
| | SUPPORTED_WINDOWS_DLLS, |
| | ) |
| |
|
| | |
# dwFlags values for LoadLibraryExW (see the Win32 LoadLibraryExW documentation).
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 1 << 8  # == 0x00000100
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 1 << 12  # == 0x00001000

# Count of distinct values a native pointer can hold on this interpreter
# (2**32 on 32-bit builds, 2**64 on 64-bit builds). Adding it to a negative
# handle value reinterprets that value as unsigned.
POINTER_ADDRESS_SPACE = 1 << (8 * struct.calcsize("P"))
| |
|
| | |
# Private kernel32 instance for this module.
# ctypes.windll is a process-wide shared loader: the argtypes/restype assigned
# below would otherwise be set on function objects that every other module using
# ctypes.windll.kernel32 also sees (and can overwrite, e.g. resetting restype to
# the default C int and truncating 64-bit handles). A dedicated WinDLL keeps
# these prototypes local. ctypes.windll.kernel32 itself is built as
# WinDLL("kernel32"), so the loaded library is the same.
# NOTE(review): use_last_error=True is intentionally NOT enabled. With it, ctypes
# restores the thread's previous last-error value after each call, which would
# change what the ctypes.GetLastError() calls in this module report.
kernel32 = ctypes.WinDLL("kernel32")

# Explicit prototypes. Without a restype, ctypes assumes a C int return value,
# which cannot hold a pointer-sized module handle on 64-bit Windows.

# HMODULE GetModuleHandleW(LPCWSTR lpModuleName)
kernel32.GetModuleHandleW.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.GetModuleHandleW.restype = ctypes.wintypes.HMODULE

# HMODULE LoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
kernel32.LoadLibraryExW.argtypes = [
    ctypes.wintypes.LPCWSTR,
    ctypes.wintypes.HANDLE,
    ctypes.wintypes.DWORD,
]
kernel32.LoadLibraryExW.restype = ctypes.wintypes.HMODULE

# DWORD GetModuleFileNameW(HMODULE hModule, LPWSTR lpFilename, DWORD nSize)
kernel32.GetModuleFileNameW.argtypes = [
    ctypes.wintypes.HMODULE,
    ctypes.wintypes.LPWSTR,
    ctypes.wintypes.DWORD,
]
kernel32.GetModuleFileNameW.restype = ctypes.wintypes.DWORD

# DLL_DIRECTORY_COOKIE AddDllDirectory(PCWSTR NewDirectory); the cookie is pointer-sized.
kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.AddDllDirectory.restype = ctypes.c_void_p
| |
|
| |
|
def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE) -> int:
    """Return a module handle as a non-negative integer.

    Args:
        handle: Module handle value; it must be convertible with ``int()``.

    Returns:
        ``int(handle)`` if it is non-negative. Otherwise the same bit pattern is
        read as an unsigned native pointer, i.e. ``int(handle) + POINTER_ADDRESS_SPACE``.
    """
    as_int = int(handle)
    # A negative value is the two's-complement view of a high address; shift it
    # into the unsigned pointer range.
    return as_int + POINTER_ADDRESS_SPACE if as_int < 0 else as_int
| |
|
| |
|
def add_dll_directory(dll_abs_path: str) -> None:
    """Add a DLL's directory to the DLL search path and to the PATH environment variable.

    The directory containing *dll_abs_path* is registered with
    ``AddDllDirectory``. It is also appended to ``PATH`` unless an equivalent
    entry is already present there.

    Args:
        dll_abs_path: Absolute path to the DLL file.

    Raises:
        AssertionError: If the directory containing the DLL does not exist.
    """
    dirpath = os.path.dirname(dll_abs_path)
    # Raise explicitly rather than via ``assert`` so the check is not stripped
    # under ``python -O``. AssertionError is kept because it is the documented type.
    if not os.path.isdir(dirpath):
        raise AssertionError(dll_abs_path)

    # Best effort: a failing AddDllDirectory (NULL cookie) is deliberately not
    # treated as an error, and PATH is updated below either way. The returned
    # cookie is not kept, so the directory is never removed again.
    kernel32.AddDllDirectory(dirpath)

    curr_path = os.environ.get("PATH")
    if not curr_path:
        # PATH is unset or empty: avoid producing a leading empty entry (";dirpath").
        os.environ["PATH"] = dirpath
        return

    # Repeated calls for the same directory must not keep growing PATH.
    # normcase makes the comparison case-insensitive on Windows.
    wanted = os.path.normcase(dirpath)
    if any(os.path.normcase(entry) == wanted for entry in curr_path.split(os.pathsep)):
        return
    os.environ["PATH"] = os.pathsep.join((curr_path, dirpath))
| |
|
| |
|
def abs_path_for_dynamic_library(libname: str, handle: ctypes.wintypes.HMODULE) -> str:
    """Return the file-system path of the loaded module identified by *handle*.

    Args:
        libname: Library name; used only in the error message.
        handle: Handle of a module loaded in the current process.

    Returns:
        The path reported by ``GetModuleFileNameW``.

    Raises:
        RuntimeError: If ``GetModuleFileNameW`` returns 0.
    """
    # First try a 260-character buffer (MAX_PATH). A return value equal to the
    # buffer size means the path was truncated, so retry once with 32768 chars.
    for capacity in (260, 32768):
        path_buf = ctypes.create_unicode_buffer(capacity)
        written = kernel32.GetModuleFileNameW(handle, path_buf, len(path_buf))
        if written == 0:
            last_error = ctypes.GetLastError()
            raise RuntimeError(f"GetModuleFileNameW failed for {libname!r} (error code: {last_error})")
        if written != capacity:
            break
    return path_buf.value
| |
|
| |
|
def check_if_already_loaded_from_elsewhere(libname: str, have_abs_path: bool) -> LoadedDL | None:
    """Return a LoadedDL if a DLL for *libname* is already loaded in this process.

    Each DLL name listed for *libname* in ``SUPPORTED_WINDOWS_DLLS`` is probed, in
    order, with ``GetModuleHandleW``. The first name that yields a handle wins.

    Args:
        libname: Library name used as the key into ``SUPPORTED_WINDOWS_DLLS``.
        have_abs_path: When true and *libname* is listed in
            ``LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY``, the directory of the
            already-loaded DLL is also registered via ``add_dll_directory``.

    Returns:
        A ``LoadedDL`` for the first match, tagged
        ``"was-already-loaded-from-elsewhere"``. Returns ``None`` if no listed
        DLL name is currently loaded.
    """
    for dll_name in SUPPORTED_WINDOWS_DLLS.get(libname, ()):
        module_handle = kernel32.GetModuleHandleW(dll_name)
        if not module_handle:
            continue  # not loaded under this name; try the next candidate
        dll_path = abs_path_for_dynamic_library(libname, module_handle)
        if have_abs_path and libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
            add_dll_directory(dll_path)
        return LoadedDL(
            dll_path,
            True,
            ctypes_handle_to_unsigned_int(module_handle),
            "was-already-loaded-from-elsewhere",
        )
    return None
| |
|
| |
|
def load_with_system_search(libname: str) -> LoadedDL | None:
    """Load *libname* using Windows' default DLL search.

    The DLL names listed for *libname* in ``SUPPORTED_WINDOWS_DLLS`` are tried
    from last to first. The first one that loads is returned.

    Args:
        libname: Library name used as the key into ``SUPPORTED_WINDOWS_DLLS``.

    Returns:
        A ``LoadedDL`` tagged ``"system-search"`` for the first DLL that loads,
        or ``None`` if none of the candidates can be loaded.
    """
    for dll_name in reversed(SUPPORTED_WINDOWS_DLLS.get(libname, ())):
        # dwFlags == 0: LoadLibraryExW behaves like plain LoadLibrary (default search order).
        module_handle = kernel32.LoadLibraryExW(dll_name, None, 0)
        if not module_handle:
            continue
        return LoadedDL(
            abs_path_for_dynamic_library(libname, module_handle),
            False,
            ctypes_handle_to_unsigned_int(module_handle),
            "system-search",
        )
    return None
| |
|
| |
|
def load_with_abs_path(libname: str, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load the DLL at an explicit path.

    If *libname* is listed in ``LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY``, the
    DLL's directory is first registered via ``add_dll_directory``. The DLL is
    then loaded with ``LoadLibraryExW``, using ``LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR |
    LOAD_LIBRARY_SEARCH_DEFAULT_DIRS``.

    Args:
        libname: The name of the library to load.
        found_path: The absolute path to the DLL file.
        found_via: Optional label passed through unchanged to the returned
            ``LoadedDL``, describing how *found_path* was located.

    Returns:
        A ``LoadedDL`` describing the newly loaded library.

    Raises:
        RuntimeError: If ``LoadLibraryExW`` fails; the Windows error code is
            included in the message.
    """
    if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
        add_dll_directory(found_path)

    # Resolve the DLL's own dependencies from its directory, and also from the
    # default safe directories (including ones added via AddDllDirectory).
    search_flags = WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
    module_handle = kernel32.LoadLibraryExW(found_path, None, search_flags)
    if module_handle:
        return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(module_handle), found_via)

    last_error = ctypes.GetLastError()
    raise RuntimeError(f"Failed to load DLL at {found_path}: Windows error {last_error}")
| |
|