File size: 6,066 Bytes
76f9669 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import ctypes
import ctypes.wintypes
import os
import struct
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY,
SUPPORTED_WINDOWS_DLLS,
)
# Mirrors WinBase.h (unfortunately not defined already elsewhere)
WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000

# Number of distinct values a native pointer can hold (2**32 or 2**64); used to
# map a negative (signed) handle value onto its unsigned equivalent.
POINTER_ADDRESS_SPACE = 2 ** (struct.calcsize("P") * 8)

# Set up kernel32 functions with proper types.
# A private WinDLL instance is used instead of the shared ``ctypes.windll.kernel32``:
# ``ctypes.windll`` caches the library object (and the library caches its function
# objects), so argtypes/restype assigned on ``windll.kernel32`` are visible to -- and
# can be overwritten by -- every other module in the process. Prototypes configured on
# a private instance stay local to this module.
kernel32 = ctypes.WinDLL("kernel32")  # type: ignore[attr-defined]

# GetModuleHandleW
kernel32.GetModuleHandleW.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.GetModuleHandleW.restype = ctypes.wintypes.HMODULE

# LoadLibraryExW
kernel32.LoadLibraryExW.argtypes = [
    ctypes.wintypes.LPCWSTR,  # lpLibFileName
    ctypes.wintypes.HANDLE,  # hFile (reserved, must be NULL)
    ctypes.wintypes.DWORD,  # dwFlags
]
kernel32.LoadLibraryExW.restype = ctypes.wintypes.HMODULE

# GetModuleFileNameW
kernel32.GetModuleFileNameW.argtypes = [
    ctypes.wintypes.HMODULE,  # hModule
    ctypes.wintypes.LPWSTR,  # lpFilename
    ctypes.wintypes.DWORD,  # nSize
]
kernel32.GetModuleFileNameW.restype = ctypes.wintypes.DWORD

# AddDllDirectory (Windows 7+)
kernel32.AddDllDirectory.argtypes = [ctypes.wintypes.LPCWSTR]
kernel32.AddDllDirectory.restype = ctypes.c_void_p  # DLL_DIRECTORY_COOKIE
def ctypes_handle_to_unsigned_int(handle: ctypes.wintypes.HMODULE) -> int:
    """Return *handle* as a non-negative integer.

    A handle whose integer value is negative (signed interpretation) is shifted
    by ``POINTER_ADDRESS_SPACE`` to obtain its unsigned equivalent; any other
    value is returned unchanged.
    """
    value = int(handle)
    return value + POINTER_ADDRESS_SPACE if value < 0 else value
def add_dll_directory(dll_abs_path: str) -> None:
    """Make the directory containing *dll_abs_path* visible for DLL resolution.

    The directory is registered with ``AddDllDirectory`` and is also appended
    to the ``PATH`` environment variable as a fallback for dependent-DLL
    resolution (unless ``PATH`` already lists it).

    Args:
        dll_abs_path: Absolute path to the DLL file.

    Raises:
        AssertionError: If the directory containing the DLL does not exist.
    """
    dirpath = os.path.dirname(dll_abs_path)
    # Explicit raise rather than ``assert`` so the check still runs under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if not os.path.isdir(dirpath):
        raise AssertionError(dll_abs_path)

    # Add the DLL directory to the search path. A zero (NULL) cookie means the
    # call failed; that is deliberately not fatal because the PATH update below
    # serves as the fallback.
    kernel32.AddDllDirectory(dirpath)

    # Update PATH as a fallback for dependent DLL resolution.
    curr_path = os.environ.get("PATH")
    if not curr_path:
        # Unset or empty PATH: avoid producing a leading empty entry.
        os.environ["PATH"] = dirpath
        return
    # This function may run many times per process (once per load), so skip
    # directories already listed: appending a duplicate at the end cannot change
    # which file is found, it only makes PATH grow without bound.
    wanted = os.path.normcase(os.path.normpath(dirpath))
    listed = {os.path.normcase(os.path.normpath(entry)) for entry in curr_path.split(os.pathsep) if entry}
    if wanted not in listed:
        os.environ["PATH"] = os.pathsep.join((curr_path, dirpath))
def abs_path_for_dynamic_library(libname: str, handle: ctypes.wintypes.HMODULE) -> str:
    """Return the absolute file path of the loaded module identified by *handle*.

    ``GetModuleFileNameW`` is first called with a MAX_PATH-sized (260) buffer.
    If the result fills that buffer exactly (so it may be truncated), the query
    is repeated once with a 32768-character buffer.

    Args:
        libname: Library name; used only in error messages.
        handle: Module handle of an already-loaded DLL.

    Raises:
        RuntimeError: If ``GetModuleFileNameW`` returns 0.
    """
    for capacity in (260, 32768):  # MAX_PATH, then extended path length
        path_buf = ctypes.create_unicode_buffer(capacity)
        written = kernel32.GetModuleFileNameW(handle, path_buf, len(path_buf))
        if written == 0:
            error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
            raise RuntimeError(f"GetModuleFileNameW failed for {libname!r} (error code: {error_code})")
        if written != len(path_buf):
            # The name fit without filling the buffer; no retry needed.
            break
    return path_buf.value
def check_if_already_loaded_from_elsewhere(libname: str, have_abs_path: bool) -> LoadedDL | None:
    """Return a LoadedDL if one of *libname*'s DLLs is already present in the process.

    Each DLL name tabulated for *libname* in ``SUPPORTED_WINDOWS_DLLS`` is
    probed with ``GetModuleHandleW``; the first one that yields a handle wins.

    Args:
        libname: The name of the library to look up.
        have_abs_path: Whether the caller has an absolute path for the library;
            if so (and the library needs it), its directory is registered via
            ``add_dll_directory``.

    Returns:
        A LoadedDL tagged "was-already-loaded-from-elsewhere", or None if no
        tabulated DLL name has a handle (including untabulated *libname*).
    """
    candidates = SUPPORTED_WINDOWS_DLLS.get(libname, ())
    for dll_name in candidates:
        handle = kernel32.GetModuleHandleW(dll_name)
        if not handle:
            continue
        abs_path = abs_path_for_dynamic_library(libname, handle)
        wants_dll_dir = have_abs_path and libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY
        if wants_dll_dir:
            # load_with_abs_path() has this side effect; apply it here as well so
            # the outcome does not depend on who loaded the library first.
            add_dll_directory(abs_path)
        return LoadedDL(abs_path, True, ctypes_handle_to_unsigned_int(handle), "was-already-loaded-from-elsewhere")
    return None
def load_with_system_search(libname: str) -> LoadedDL | None:
    """Attempt to load *libname* by DLL name only, letting Windows search for it.

    Args:
        libname: The name of the library to load.

    Returns:
        A LoadedDL tagged "system-search" for the first tabulated DLL name that
        loads, or None if none of them can be loaded.
    """
    dll_names = SUPPORTED_WINDOWS_DLLS.get(libname, ())
    # Walk the tabulated names back to front to achieve new -> old search order.
    for dll_name in reversed(dll_names):
        # No hFile and dwFlags=0: plain name-based load, no LOAD_LIBRARY_* flags.
        handle = kernel32.LoadLibraryExW(dll_name, None, 0)
        if not handle:
            continue
        abs_path = abs_path_for_dynamic_library(libname, handle)
        return LoadedDL(abs_path, False, ctypes_handle_to_unsigned_int(handle), "system-search")
    return None
def load_with_abs_path(libname: str, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load the DLL located at *found_path*.

    Args:
        libname: The name of the library to load.
        found_path: The absolute path to the DLL file.
        found_via: Optional note on how *found_path* was located; passed through
            unchanged to the returned LoadedDL.

    Returns:
        A LoadedDL object representing the loaded library.

    Raises:
        RuntimeError: If ``LoadLibraryExW`` fails to load the DLL.
    """
    if libname in LIBNAMES_REQUIRING_OS_ADD_DLL_DIRECTORY:
        add_dll_directory(found_path)

    # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | LOAD_LIBRARY_SEARCH_DEFAULT_DIRS.
    load_flags = WINBASE_LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR | WINBASE_LOAD_LIBRARY_SEARCH_DEFAULT_DIRS
    handle = kernel32.LoadLibraryExW(found_path, None, load_flags)
    if handle:
        return LoadedDL(found_path, False, ctypes_handle_to_unsigned_int(handle), found_via)
    # Read the error immediately after the failing call, before anything else can reset it.
    error_code = ctypes.GetLastError()  # type: ignore[attr-defined]
    raise RuntimeError(f"Failed to load DLL at {found_path}: Windows error {error_code}")
|