File size: 8,010 Bytes
76f9669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import ctypes
import ctypes.util
import os
from typing import cast

from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
    LIBNAMES_REQUIRING_RTLD_DEEPBIND,
    SUPPORTED_LINUX_SONAMES,
)

# Base dlopen flags used by _load_lib() and the nvrtc workaround below:
# resolve all undefined symbols at load time (RTLD_NOW) and export the
# library's symbols for libraries loaded afterwards (RTLD_GLOBAL).
CDLL_MODE = os.RTLD_GLOBAL | os.RTLD_NOW


def _load_libdl() -> ctypes.CDLL:
    # In normal glibc-based Linux environments, find_library("dl") should return
    # something like "libdl.so.2". In minimal or stripped-down environments
    # (no ldconfig/gcc, incomplete linker cache), this can return None even
    # though libdl is present. In that case, we fall back to the stable SONAME.
    name = ctypes.util.find_library("dl") or "libdl.so.2"
    try:
        return ctypes.CDLL(name)
    except OSError as e:
        raise RuntimeError(f"Could not load {name!r} (required for dlinfo/dlerror on Linux)") from e


# Module-wide handle to libdl. It is loaded at import time, so a missing libdl
# surfaces as the RuntimeError raised by _load_libdl() when this module is imported.
LIBDL = _load_libdl()

# dlinfo: int dlinfo(void *handle, int request, void *info)
LIBDL.dlinfo.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]
LIBDL.dlinfo.restype = ctypes.c_int

# dlerror (thread-local error string; cleared after read).
# With restype c_char_p, ctypes returns bytes, or None when no error is pending.
LIBDL.dlerror.argtypes = []
LIBDL.dlerror.restype = ctypes.c_char_p

# dlinfo() request codes (glibc <dlfcn.h>).
# First appeared in 2004-era glibc. Universally correct on Linux for all practical purposes.
RTLD_DI_LINKMAP = 2  # `info` receives a `struct link_map *` (see l_name_for_dynamic_library)
RTLD_DI_ORIGIN = 6  # `info` is a caller-provided char buffer (see l_origin_for_dynamic_library)


class _LinkMapLNameView(ctypes.Structure):
    """
    Prefix-only view of glibc's `struct link_map` used **solely** to read `l_name`.

    Background:
      - `dlinfo(handle, RTLD_DI_LINKMAP, ...)` returns a `struct link_map*`.
      - The first few members of `struct link_map` (including `l_name`) have been
        stable on glibc for decades and are documented as debugger-visible.
      - We only need the offset/layout of `l_name`, not the full struct.

    Safety constraints:
      - This is a **partial** definition (prefix). It must only be used via a pointer
        returned by `dlinfo(...)`.
      - Do **not** instantiate it or pass it **by value** to any C function.
      - Do **not** access any members beyond those declared here.
      - Do **not** rely on `ctypes.sizeof(LinkMapPrefix)` for allocation.

    Rationale:
      - Defining only the leading fields avoids depending on internal/unstable
        tail members while keeping code more readable than raw pointer arithmetic.
    """

    _fields_ = (
        ("l_addr", ctypes.c_void_p),  # ElfW(Addr)
        ("l_name", ctypes.c_char_p),  # char*
    )


# Defensive assertions, mainly to document the invariants we depend on:
# `l_addr` is the first member and `l_name` follows it exactly one pointer
# width later, mirroring the leading members of glibc's `struct link_map`.
# NOTE: `assert` statements are removed under `python -O`, so these act as
# documentation/debug checks rather than a guaranteed runtime validation.
assert _LinkMapLNameView.l_addr.offset == 0
assert _LinkMapLNameView.l_name.offset == ctypes.sizeof(ctypes.c_void_p)


def _dl_last_error() -> str | None:
    """Fetch the pending ``dlerror()`` message, or None if there is none.

    ``dlerror`` keeps a thread-local error string that is cleared once read,
    so calling this consumes the message. Decoding uses UTF-8 with
    ``surrogateescape``, which never raises: undecodable bytes become
    U+DC80..U+DCFF.
    """
    raw_message = cast(bytes | None, LIBDL.dlerror())
    return raw_message.decode("utf-8", "surrogateescape") if raw_message else None


def l_name_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Return ``link_map->l_name`` for an already-opened library handle.

    Uses ``dlinfo(RTLD_DI_LINKMAP)`` to obtain the loader's ``struct link_map *``
    and reads only its ``l_name`` member through ``_LinkMapLNameView``.
    Callers in this module use only the basename of the result.

    Args:
        libname: Library name, used only in error messages.
        handle: The ``ctypes.CDLL`` whose ``_handle`` is queried.

    Raises:
        OSError: If ``dlinfo`` fails, returns a NULL link_map, or the
            recorded name is NULL/empty.
    """
    link_map_ptr = ctypes.POINTER(_LinkMapLNameView)()
    rc = LIBDL.dlinfo(ctypes.c_void_p(handle._handle), RTLD_DI_LINKMAP, ctypes.byref(link_map_ptr))
    if rc != 0:
        err = _dl_last_error()
        detail = f": {err}" if err else ""
        raise OSError(f"dlinfo failed for {libname=!r} (rc={rc})" + detail)

    # A NULL ctypes pointer is falsy; dlinfo wrote a `struct link_map *` here.
    if not link_map_ptr:
        raise OSError(f"dlinfo returned NULL link_map pointer for {libname=!r}")

    # c_char_p gives None for a NULL char* and b"" for an empty string.
    raw_name = link_map_ptr.contents.l_name
    if not raw_name:
        raise OSError(f"dlinfo returned empty link_map->l_name for {libname=!r}")

    decoded_name = os.fsdecode(raw_name)
    if not decoded_name:
        raise OSError(f"dlinfo returned empty l_name string for {libname=!r}")
    return decoded_name


def l_origin_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Return the origin directory of an opened library via ``dlinfo(RTLD_DI_ORIGIN)``.

    ``dlinfo`` copies the origin path into a caller-supplied buffer. A
    4096-byte buffer is used (PATH_MAX on Linux).

    Args:
        libname: Library name, used only in error messages.
        handle: The ``ctypes.CDLL`` whose ``_handle`` is queried.

    Raises:
        OSError: If ``dlinfo`` fails or reports an empty origin.
    """
    origin_buffer = ctypes.create_string_buffer(4096)
    if (rc := LIBDL.dlinfo(ctypes.c_void_p(handle._handle), RTLD_DI_ORIGIN, origin_buffer)) != 0:
        err = _dl_last_error()
        detail = f": {err}" if err else ""
        raise OSError(f"dlinfo failed for {libname=!r} (rc={rc})" + detail)

    # `.value` stops at the first NUL byte.
    origin_dir = os.fsdecode(origin_buffer.value)
    if origin_dir:
        return origin_dir
    raise OSError(f"dlinfo returned empty l_origin string for {libname=!r}")


def abs_path_for_dynamic_library(libname: str, handle: ctypes.CDLL) -> str:
    """Build the absolute path of an opened library.

    The result joins the loader-reported origin directory with the basename
    of ``link_map->l_name``. ``l_name`` is queried first, then the origin;
    any ``OSError`` from either query propagates.
    """
    file_basename = os.path.basename(l_name_for_dynamic_library(libname, handle))
    origin_dir = l_origin_for_dynamic_library(libname, handle)
    return os.path.join(origin_dir, file_basename)


def get_candidate_sonames(libname: str) -> list[str]:
    """List SONAMEs to try for ``libname``, most preferred first.

    The tabulated Linux SONAMEs are reversed to get a new → old search order.
    The unversioned ``lib<libname>.so`` is always appended as a last resort.
    """
    tabulated = SUPPORTED_LINUX_SONAMES.get(libname, ())
    return [*reversed(tabulated), f"lib{libname}.so"]


def check_if_already_loaded_from_elsewhere(libname: str, _have_abs_path: bool) -> LoadedDL | None:
    """Return a LoadedDL if some candidate SONAME of ``libname`` is already resident.

    Each candidate from ``get_candidate_sonames`` is probed with
    ``RTLD_NOLOAD``, which does not load anything new. The first probe that
    succeeds is reported with the tag "was-already-loaded-from-elsewhere".

    Args:
        libname: The library to look for.
        _have_abs_path: Not used by this implementation.

    Returns:
        The LoadedDL for the first resident candidate, else None.
    """
    for candidate in get_candidate_sonames(libname):
        try:
            resident = ctypes.CDLL(candidate, mode=os.RTLD_NOLOAD)
        except OSError:
            continue  # not currently loaded under this SONAME
        resident_path = abs_path_for_dynamic_library(libname, resident)
        return LoadedDL(resident_path, True, resident._handle, "was-already-loaded-from-elsewhere")
    return None


def _load_lib(libname: str, filename: str) -> ctypes.CDLL:
    """dlopen ``filename`` with ``CDLL_MODE``.

    ``RTLD_DEEPBIND`` is added for libraries listed in
    ``LIBNAMES_REQUIRING_RTLD_DEEPBIND``. ``os.RTLD_DEEPBIND`` is only looked
    up in that case. Loader failures surface as ``OSError`` from ctypes.
    """
    extra_flags = os.RTLD_DEEPBIND if libname in LIBNAMES_REQUIRING_RTLD_DEEPBIND else 0
    return ctypes.CDLL(filename, CDLL_MODE | extra_flags)


def load_with_system_search(libname: str) -> LoadedDL | None:
    """Try to load a library using system search paths.

    Candidates come from ``get_candidate_sonames`` (newest tabulated SONAME
    first, then the unversioned ``lib<libname>.so``). The first one that the
    dynamic loader accepts is used.

    Args:
        libname: The name of the library to load

    Returns:
        A LoadedDL object if successful, None if no candidate can be loaded

    Raises:
        OSError: If a candidate loads but its absolute path cannot be
            determined via ``dlinfo`` (propagated from
            ``abs_path_for_dynamic_library``).
    """
    for soname in get_candidate_sonames(libname):
        try:
            handle = _load_lib(libname, soname)
        except OSError:
            # Not loadable under this SONAME; try the next candidate.
            continue
        # abs_path_for_dynamic_library() returns a non-empty str or raises
        # OSError, so the former `is None` check was unreachable and is dropped.
        abs_path = abs_path_for_dynamic_library(libname, handle)
        return LoadedDL(abs_path, False, handle._handle, "system-search")
    return None


def _work_around_known_bugs(libname: str, found_path: str) -> None:
    if libname == "nvrtc":
        # Work around bug/oversight in
        #   nvidia_cuda_nvrtc-13.0.48-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
        # Issue: libnvrtc.so.13 RUNPATH is not set.
        # This workaround is highly specific
        #   - for simplicity.
        #   - to not mask bugs in future nvidia-cuda-nvrtc releases.
        #   - because a more general workaround is complicated.
        dirname, basename = os.path.split(found_path)
        if basename == "libnvrtc.so.13":
            dep_basename = "libnvrtc-builtins.so.13.0"
            dep_path = os.path.join(dirname, dep_basename)
            if os.path.isfile(dep_path):
                # In case of failure, defer to primary load, which is almost certain to fail, too.
                with contextlib.suppress(OSError):
                    ctypes.CDLL(dep_path, CDLL_MODE)


def load_with_abs_path(libname: str, found_path: str, found_via: str | None = None) -> LoadedDL:
    """Load a dynamic library from an explicit path.

    Known packaging defects are worked around first (see
    ``_work_around_known_bugs``). The library is then opened via ``_load_lib``.

    Args:
        libname: The name of the library to load
        found_path: The absolute path to the library file
        found_via: Label stored as the last field of the returned LoadedDL. It
            fills the same slot that other loaders in this module set to
            e.g. "system-search". May be None.

    Returns:
        A LoadedDL object for ``found_path``

    Raises:
        RuntimeError: If the library cannot be loaded (the loader's
            ``OSError`` is chained as the cause).
    """
    _work_around_known_bugs(libname, found_path)
    try:
        loaded = _load_lib(libname, found_path)
    except OSError as exc:
        raise RuntimeError(f"Failed to dlopen {found_path}: {exc}") from exc
    raw_handle = loaded._handle
    return LoadedDL(found_path, False, raw_handle, found_via)