File size: 7,315 Bytes
76f9669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import functools
import glob
import os
from dataclasses import dataclass

from cuda.pathfinder._headers import supported_nvidia_headers
from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path
from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS


@dataclass
class LocatedHeaderDir:
    """Outcome of a header-directory lookup.

    ``abs_path`` is canonicalized on construction via ``_abs_norm`` (so a falsy
    value is stored as ``None``). ``found_via`` records which search strategy
    produced the directory.
    """

    # Directory that contains the requested header, or None.
    abs_path: str | None
    # Search strategy label, e.g. "site-packages", "conda", "CUDA_HOME",
    # or "supported_install_dir".
    found_via: str

    def __post_init__(self) -> None:
        # Replace whatever the caller passed with its normalized absolute form.
        canonical = _abs_norm(self.abs_path)
        self.abs_path = canonical


def _abs_norm(path: str | None) -> str | None:
    if path:
        return os.path.normpath(os.path.abspath(path))
    return None


def _joined_isfile(dirpath: str, basename: str) -> bool:
    return os.path.isfile(os.path.join(dirpath, basename))


def _locate_under_site_packages(sub_dir: str, h_basename: str) -> LocatedHeaderDir | None:
    """Find a wheel-installed header directory in any site-packages location.

    ``sub_dir`` is a "/"-separated relative path. It is split into components
    for ``find_sub_dirs_all_sitepackages``. The first returned directory that
    contains ``h_basename`` wins.
    """
    path_components = tuple(sub_dir.split("/"))
    matching_dirs = (
        candidate
        for candidate in find_sub_dirs_all_sitepackages(path_components)
        if _joined_isfile(candidate, h_basename)
    )
    first_match: str | None = next(matching_dirs, None)  # annotated to help mypy
    if first_match is None:
        return None
    return LocatedHeaderDir(abs_path=first_match, found_via="site-packages")


def _locate_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str) -> str | None:
    """Return the directory under ``anchor_point`` holding ``h_basename``, or ``None``.

    Headers are expected in ``<anchor>/include`` (``<anchor>/nvvm/include`` for nvvm).
    For cccl, more specific subdirectories are probed before the plain include dir:
    on Windows, the ``include/targets/x64[/cccl]`` layout (a conda anomaly), then
    ``include/cccl`` (CTK 13).
    """
    if libname == "nvvm":
        include_dir = os.path.join(anchor_point, libname, "include")
    else:
        include_dir = os.path.join(anchor_point, "include")

    # Ordered from most to least specific; the first hit wins.
    probe_dirs: list[str] = []
    if libname == "cccl":
        if IS_WINDOWS:
            conda_targets_dir = os.path.join(include_dir, "targets", "x64")  # conda has this anomaly
            probe_dirs.append(os.path.join(conda_targets_dir, "cccl"))  # CTK 13 under the conda layout
            probe_dirs.append(conda_targets_dir)  # CTK 12 under the conda layout
        probe_dirs.append(os.path.join(include_dir, "cccl"))  # CTK 13
    probe_dirs.append(include_dir)

    return next((d for d in probe_dirs if _joined_isfile(d, h_basename)), None)


def _find_based_on_conda_layout(libname: str, h_basename: str, ctk_layout: bool) -> LocatedHeaderDir | None:
    """Look for ``h_basename`` inside the active Conda environment (``$CONDA_PREFIX``).

    Args:
        libname: Short library name, forwarded to ``_locate_based_on_ctk_layout``.
        h_basename: Header path (relative to the include directory) that must exist.
        ctk_layout: On non-Windows platforms, anchor the search at
            ``$CONDA_PREFIX/targets/<arch>`` (CUDA Toolkit packaging) instead of
            ``$CONDA_PREFIX``. Ignored on Windows, where ``$CONDA_PREFIX/Library``
            is always the anchor.

    Returns:
        A LocatedHeaderDir with ``found_via="conda"``, or ``None``. ``None`` is
        returned when ``CONDA_PREFIX`` is unset/empty, the expected layout is
        missing or ambiguous (more than one ``targets/<arch>``), or the header is
        not present.
    """
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if not conda_prefix:
        return None
    if IS_WINDOWS:
        anchor_point = os.path.join(conda_prefix, "Library")
        if not os.path.isdir(anchor_point):
            return None
    else:
        if ctk_layout:
            # Escape the prefix so glob metacharacters in the environment path
            # (e.g. "[" or "*") are matched literally; only <arch> is a wildcard.
            targets_include_pattern = os.path.join(glob.escape(conda_prefix), "targets", "*", "include")
            targets_include_path = glob.glob(targets_include_pattern)
            if not targets_include_path:
                return None
            if len(targets_include_path) != 1:
                # Conda does not support multiple architectures.
                # QUESTION(PR#956): Do we want to issue a warning?
                return None
            include_path = targets_include_path[0]
        else:
            include_path = os.path.join(conda_prefix, "include")
        anchor_point = os.path.dirname(include_path)
    found_header_path = _locate_based_on_ctk_layout(libname, h_basename, anchor_point)
    if found_header_path:
        return LocatedHeaderDir(abs_path=found_header_path, found_via="conda")
    return None


def _find_ctk_header_directory(libname: str) -> LocatedHeaderDir | None:
    """Search for a CUDA Toolkit library's headers.

    The search tries, in order: wheel-installed site-packages directories, the
    active Conda environment (CTK ``targets/<arch>`` layout), and finally the
    toolkit root reported by ``get_cuda_home_or_path()``.
    """
    h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_CTK[libname]
    site_package_sub_dirs = supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK[libname]

    # 1. NVIDIA wheels.
    for sub_dir in site_package_sub_dirs:
        located = _locate_under_site_packages(sub_dir, h_basename)
        if located:
            return located

    # 2. Conda environment.
    located = _find_based_on_conda_layout(libname, h_basename, ctk_layout=True)
    if located:
        return located

    # 3. CUDA_HOME / CUDA_PATH.
    cuda_home = get_cuda_home_or_path()
    if not cuda_home:
        return None
    header_dir = _locate_based_on_ctk_layout(libname, h_basename, cuda_home)
    if not header_dir:
        return None
    return LocatedHeaderDir(abs_path=header_dir, found_via="CUDA_HOME")


@functools.cache
def locate_nvidia_header_directory(libname: str) -> LocatedHeaderDir | None:
    """Locate the header directory for a supported NVIDIA library.

    Args:
        libname (str): The short name of the library whose headers are needed
            (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).

    Returns:
        LocatedHeaderDir or None: A LocatedHeaderDir object containing the absolute path
        to the discovered header directory and information about where it was found
        (``found_via``), or ``None`` if the headers cannot be found.

    Raises:
        RuntimeError: If ``libname`` is not in the supported set.

    Search order:
        1. **NVIDIA Python wheels**

           - Scan installed distributions (``site-packages``) for header layouts
             shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).

        2. **Conda environments**

           - Check Conda-style installation prefixes, which use platform-specific
             include directory layouts.

        3. **CUDA Toolkit environment variables** (CUDA Toolkit libraries only)

           - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).

        4. **Supported system install directories** (non-CTK libraries only)

           - Glob each known install-location pattern for the library and check
             the matches in reverse-sorted order.

    Results are cached per ``libname`` (``functools.cache``): repeated calls
    return the same LocatedHeaderDir instance, so callers should not mutate it.
    """

    if libname in supported_nvidia_headers.SUPPORTED_HEADERS_CTK:
        return _find_ctk_header_directory(libname)

    h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_NON_CTK.get(libname)
    if h_basename is None:
        raise RuntimeError(f"UNKNOWN {libname=}")

    candidate_dirs = supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_NON_CTK.get(libname, [])

    for cdir in candidate_dirs:
        if found_hdr := _locate_under_site_packages(cdir, h_basename):
            return found_hdr

    if found_hdr := _find_based_on_conda_layout(libname, h_basename, False):
        return found_hdr

    # Fall back to system install directories. Entries may be glob patterns;
    # matches are tried lexicographically-greatest first (presumably so the
    # newest versioned directory wins — confirm against the install-dir table).
    candidate_dirs = supported_nvidia_headers.SUPPORTED_INSTALL_DIRS_NON_CTK.get(libname, [])
    for cdir in candidate_dirs:
        for hdr_dir in sorted(glob.glob(cdir), reverse=True):
            if _joined_isfile(hdr_dir, h_basename):
                return LocatedHeaderDir(abs_path=hdr_dir, found_via="supported_install_dir")
    return None


def find_nvidia_header_directory(libname: str) -> str | None:
    """Locate the header directory for a supported NVIDIA library.

    This is a convenience wrapper around ``locate_nvidia_header_directory`` that
    returns only the path.

    Args:
        libname (str): The short name of the library whose headers are needed
            (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).

    Returns:
        str or None: Absolute path to the discovered header directory, or ``None``
        if the headers cannot be found.

    Raises:
        RuntimeError: If ``libname`` is not in the supported set.

    Search order:
        1. **NVIDIA Python wheels**

           - Scan installed distributions (``site-packages``) for header layouts
             shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).

        2. **Conda environments**

           - Check Conda-style installation prefixes, which use platform-specific
             include directory layouts.

        3. **CUDA Toolkit environment variables** (CUDA Toolkit libraries only)

           - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).

        4. **Supported system install directories** (non-CTK libraries only)

           - Glob each known install-location pattern for the library and check
             the matches in reverse-sorted order.
    """
    found = locate_nvidia_header_directory(libname)
    return found.abs_path if found else None