# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import contextlib
import ctypes
import ctypes.util
import os
from typing import cast
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
LIBNAMES_REQUIRING_RTLD_DEEPBIND,
SUPPORTED_LINUX_SONAMES,
)
# Default dlopen() flags used when this module loads a library: resolve all
# symbols at load time (RTLD_NOW) and make them available to libraries loaded
# afterwards (RTLD_GLOBAL).
CDLL_MODE = os.RTLD_NOW | os.RTLD_GLOBAL
def _load_libdl() -> ctypes.CDLL:
    """Open libdl and return the ctypes handle for it.

    Returns:
        The loaded libdl as a `ctypes.CDLL`.

    Raises:
        RuntimeError: If libdl cannot be opened (chained from the `OSError`).
    """
    # On ordinary glibc-based Linux systems, find_library("dl") yields a name
    # such as "libdl.so.2". Minimal or stripped-down environments (no
    # ldconfig/gcc, incomplete linker cache) may make it return None although
    # libdl is present, so the stable SONAME is used as the fallback.
    located = ctypes.util.find_library("dl")
    name = located if located else "libdl.so.2"
    try:
        libdl = ctypes.CDLL(name)
    except OSError as exc:
        raise RuntimeError(f"Could not load {name!r} (required for dlinfo/dlerror on Linux)") from exc
    return libdl
# Handle to libdl, opened once at import time; it provides dlinfo() and dlerror().
LIBDL = _load_libdl()

# dlinfo(handle, request, info) -> int
# Callers in this module treat any nonzero return value as failure.
LIBDL.dlinfo.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]
LIBDL.dlinfo.restype = ctypes.c_int

# dlerror (thread-local error string; cleared after read)
LIBDL.dlerror.argtypes = []
LIBDL.dlerror.restype = ctypes.c_char_p

# Request codes passed as the second argument of dlinfo().
# First appeared in 2004-era glibc. Universally correct on Linux for all practical purposes.
RTLD_DI_LINKMAP = 2  # used below to obtain the `struct link_map*` of a handle
RTLD_DI_ORIGIN = 6  # used below to copy the origin path of a handle into a caller buffer
class _LinkMapLNameView(ctypes.Structure):
    """
    Prefix-only view of glibc's `struct link_map` used **solely** to read `l_name`.

    Background:
    - `dlinfo(handle, RTLD_DI_LINKMAP, ...)` returns a `struct link_map*`.
    - The first few members of `struct link_map` (including `l_name`) have been
      stable on glibc for decades and are documented as debugger-visible.
    - We only need the offset/layout of `l_name`, not the full struct.

    Safety constraints:
    - This is a **partial** definition (prefix). It must only be used via a pointer
      returned by `dlinfo(...)`.
    - Do **not** instantiate it or pass it **by value** to any C function.
    - Do **not** access any members beyond those declared here.
    - Do **not** rely on `ctypes.sizeof(_LinkMapLNameView)` for allocation.

    Rationale:
    - Defining only the leading fields avoids depending on internal/unstable
      tail members while keeping code more readable than raw pointer arithmetic.
    """

    # Only the two leading members are declared; `l_name` is the one that is read.
    _fields_ = (
        ("l_addr", ctypes.c_void_p),  # ElfW(Addr)
        ("l_name", ctypes.c_char_p),  # char*
    )
# Defensive assertions, mainly to document the invariants we depend on:
# `l_addr` is the first member, and `l_name` follows it at exactly one pointer width.
assert _LinkMapLNameView.l_addr.offset == 0
assert _LinkMapLNameView.l_name.offset == ctypes.sizeof(ctypes.c_void_p)
def _dl_last_error() -> str | None:
    """Fetch the pending `dlerror()` message, if any.

    Returns:
        The message decoded as text, or None when `dlerror()` reports no
        pending error (NULL or an empty string).
    """
    raw = cast(bytes | None, LIBDL.dlerror())
    # Decoding never raises; undecodable bytes are mapped to U+DC80..U+DCFF.
    return raw.decode("utf-8", "surrogateescape") if raw else None
def l_name_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Return `link_map->l_name` for an already-loaded library.

    The value is obtained with `dlinfo(handle, RTLD_DI_LINKMAP, ...)`.

    Args:
        libname: Library name; used only to label error messages.
        handle: The loaded library whose `_handle` is queried.

    Returns:
        The non-empty `l_name` string, decoded with `os.fsdecode`.

    Raises:
        OSError: If `dlinfo` reports failure, leaves the `link_map` pointer
            NULL, or the `l_name` field is NULL or empty.
    """
    lm_view = ctypes.POINTER(_LinkMapLNameView)()
    rc = LIBDL.dlinfo(ctypes.c_void_p(handle._handle), RTLD_DI_LINKMAP, ctypes.byref(lm_view))
    if rc != 0:
        err = _dl_last_error()
        raise OSError(f"dlinfo failed for {libname=!r} (rc={rc})" + (f": {err}" if err else ""))
    if not lm_view:  # dlinfo succeeded but left the link_map* NULL
        raise OSError(f"dlinfo returned NULL link_map pointer for {libname=!r}")

    l_name_bytes = lm_view.contents.l_name
    if not l_name_bytes:  # covers both a NULL char* (None) and an empty string
        raise OSError(f"dlinfo returned empty link_map->l_name for {libname=!r}")

    # Non-empty bytes always decode to a non-empty str, so no further emptiness
    # check is needed after decoding.
    return os.fsdecode(l_name_bytes)
def l_origin_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Return the origin path that `dlinfo` reports for an already-loaded library.

    Args:
        libname: Library name; used only to label error messages.
        handle: The loaded library whose `_handle` is queried.

    Returns:
        The non-empty origin string, decoded with `os.fsdecode`.

    Raises:
        OSError: If `dlinfo` reports failure or writes an empty string.
    """
    # dlinfo is handed only the buffer, not its size.
    # NOTE(review): 4096 presumably mirrors Linux PATH_MAX - confirm.
    origin_buffer = ctypes.create_string_buffer(4096)
    status = LIBDL.dlinfo(ctypes.c_void_p(handle._handle), RTLD_DI_ORIGIN, origin_buffer)
    if status != 0:
        detail = _dl_last_error()
        suffix = f": {detail}" if detail else ""
        raise OSError(f"dlinfo failed for {libname=!r} (rc={status})" + suffix)

    origin = os.fsdecode(origin_buffer.value)
    if origin:
        return origin
    raise OSError(f"dlinfo returned empty l_origin string for {libname=!r}")
def abs_path_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Build the path of a loaded library from its dlinfo origin and file name.

    Args:
        libname: Library name; forwarded to the dlinfo helpers for error messages.
        handle: The loaded library to inspect.

    Returns:
        The origin path joined with the basename of `l_name`.

    Raises:
        OSError: Propagated from the `l_name` / `l_origin` lookups.
    """
    # Only the file-name component of l_name is kept; the directory comes from l_origin.
    filename = os.path.basename(l_name_for_dynamic_library(libname, handle))
    directory = l_origin_for_dynamic_library(libname, handle)
    return os.path.join(directory, filename)
def get_candidate_sonames(libname: str) -> list[str]:
    """List the SONAMEs to try for `libname`, in search order.

    Args:
        libname: The short library name (without the "lib" prefix or ".so" suffix).

    Returns:
        The names tabulated in `SUPPORTED_LINUX_SONAMES` for `libname` in
        reverse order, followed by the unversioned `lib<libname>.so`.
    """
    tabulated = SUPPORTED_LINUX_SONAMES.get(libname, ())
    # Tabulated names are reversed to achieve a new -> old search order; the
    # unversioned name is always the last candidate.
    return [*reversed(tabulated), f"lib{libname}.so"]
def check_if_already_loaded_from_elsewhere(libname: str, _have_abs_path: bool) -> LoadedDL | None:
    """Report whether one of the candidate SONAMEs is already loaded in this process.

    Args:
        libname: The name of the library to look for.
        _have_abs_path: Not used by this implementation.

    Returns:
        A LoadedDL for the first candidate that is already loaded, or None if
        none of them is.

    Raises:
        OSError: If a candidate is loaded but its path cannot be determined.
    """
    for candidate in get_candidate_sonames(libname):
        try:
            # RTLD_NOLOAD asks dlopen not to load anything: the call only
            # succeeds when the library is already resident.
            existing = ctypes.CDLL(candidate, mode=os.RTLD_NOLOAD)
        except OSError:
            continue
        resolved_path = abs_path_for_dynamic_library(libname, existing)
        return LoadedDL(resolved_path, True, existing._handle, "was-already-loaded-from-elsewhere")
    return None
def _load_lib(libname: str, filename: str) -> ctypes.CDLL:
    """Open `filename` with the module's default dlopen flags.

    Args:
        libname: The library name; decides whether RTLD_DEEPBIND is added.
        filename: The SONAME or path handed to `ctypes.CDLL`.

    Returns:
        The loaded library.

    Raises:
        OSError: Propagated from `ctypes.CDLL` when loading fails.
    """
    needs_deepbind = libname in LIBNAMES_REQUIRING_RTLD_DEEPBIND
    flags = (CDLL_MODE | os.RTLD_DEEPBIND) if needs_deepbind else CDLL_MODE
    return ctypes.CDLL(filename, flags)
def load_with_system_search(libname: str) -> LoadedDL | None:
    """Try to load a library using system search paths.

    Each candidate SONAME is tried in turn; the first one that loads wins.

    Args:
        libname: The name of the library to load

    Returns:
        A LoadedDL object if successful, None if none of the candidate
        SONAMEs can be loaded

    Raises:
        OSError: If a candidate was loaded but its absolute path could not be
            determined via dlinfo
    """
    for soname in get_candidate_sonames(libname):
        try:
            handle = _load_lib(libname, soname)
        except OSError:
            # This SONAME could not be loaded; move on to the next candidate.
            continue
        # Kept outside the try on purpose: a dlinfo failure for a library that
        # did load must propagate instead of being mistaken for "not found".
        # The helper returns a str or raises, so no None check is needed.
        abs_path = abs_path_for_dynamic_library(libname, handle)
        return LoadedDL(abs_path, False, handle._handle, "system-search")
    return None
def _work_around_known_bugs(libname: str, found_path: str) -> None:
    """Pre-load a dependency that a known-broken nvrtc wheel fails to locate.

    Does nothing unless `libname` is "nvrtc", the file name of `found_path` is
    "libnvrtc.so.13", and "libnvrtc-builtins.so.13.0" exists next to it.

    Args:
        libname: The name of the library about to be loaded.
        found_path: The path at which that library was found.
    """
    if libname != "nvrtc":
        return

    # Work around bug/oversight in
    # nvidia_cuda_nvrtc-13.0.48-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
    # Issue: libnvrtc.so.13 RUNPATH is not set.
    # This workaround is highly specific
    # - for simplicity.
    # - to not mask bugs in future nvidia-cuda-nvrtc releases.
    # - because a more general workaround is complicated.
    lib_dir, lib_file = os.path.split(found_path)
    if lib_file != "libnvrtc.so.13":
        return

    builtins_path = os.path.join(lib_dir, "libnvrtc-builtins.so.13.0")
    if not os.path.isfile(builtins_path):
        return

    # In case of failure, defer to primary load, which is almost certain to fail, too.
    with contextlib.suppress(OSError):
        ctypes.CDLL(builtins_path, CDLL_MODE)
def load_with_abs_path(libname: str, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load a dynamic library from an explicit path.

    Known packaging bugs are worked around before the library itself is opened.

    Args:
        libname: The name of the library to load
        found_path: The absolute path to the library file
        found_via: Passed through unchanged as the last field of the returned LoadedDL

    Returns:
        A LoadedDL object representing the loaded library

    Raises:
        RuntimeError: If the library cannot be loaded
    """
    _work_around_known_bugs(libname, found_path)
    try:
        lib = _load_lib(libname, found_path)
    except OSError as exc:
        # Re-raise as RuntimeError, keeping the OSError as the cause.
        raise RuntimeError(f"Failed to dlopen {found_path}: {exc}") from exc
    else:
        return LoadedDL(found_path, False, lib._handle, found_via)
|