| |
| |
|
|
| import functools |
| import os |
| from dataclasses import dataclass |
| from typing import NoReturn, TypedDict |
|
|
| from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path |
| from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages |
| from cuda.pathfinder._utils.platform_aware import IS_WINDOWS |
|
|
|
|
class BitcodeLibNotFoundError(RuntimeError):
    """Error raised when none of the search strategies finds the requested bitcode library.

    Derives from :class:`RuntimeError`, so ``except RuntimeError`` handlers also catch it.
    """
|
|
|
|
@dataclass(frozen=True)
class LocatedBitcodeLib:
    """Immutable record describing where a bitcode library was found."""

    # Supported library name that was requested (a key such as "device").
    name: str
    # Full path of the library file that was found.
    abs_path: str
    # Bare file name of the library (e.g. "libdevice.10.bc").
    filename: str
    # Strategy that succeeded; locate_bitcode_lib() uses "site-packages", "conda" or "CUDA_HOME".
    found_via: str
|
|
|
|
class _BitcodeLibInfo(TypedDict):
    """Static search configuration for one supported bitcode library."""

    # File name to look for in every candidate directory.
    filename: str
    # Directory (built with os.path.join) relative to a CUDA_HOME/CUDA_PATH root or conda prefix.
    rel_path: str
    # "/"-separated directories relative to each site-packages root, tried in order.
    site_packages_dirs: tuple[str, ...]
|
|
|
|
# Search configuration keyed by the library name accepted by
# locate_bitcode_lib() / find_bitcode_lib(). Calling the TypedDict produces a
# plain dict, but lets type checkers validate the keys of each entry.
_SUPPORTED_BITCODE_LIBS_INFO: dict[str, _BitcodeLibInfo] = {
    "device": _BitcodeLibInfo(
        filename="libdevice.10.bc",
        # Joined onto CUDA_HOME/CUDA_PATH or the conda prefix.
        rel_path=os.path.join("nvvm", "libdevice"),
        # Relative to each site-packages root; earlier entries are tried first.
        site_packages_dirs=(
            "nvidia/cu13/nvvm/libdevice",
            "nvidia/cuda_nvcc/nvvm/libdevice",
        ),
    ),
}

# Public, alphabetically sorted list of the names above (iterating a dict yields its keys).
SUPPORTED_BITCODE_LIBS: tuple[str, ...] = tuple(sorted(_SUPPORTED_BITCODE_LIBS_INFO))
|
|
|
|
def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str], attachments: list[str]) -> None:
    """Record diagnostics for ``filename`` missing from ``dir_path``.

    Appends a one-line summary to ``error_messages``. Appends to ``attachments``
    either a sorted listing of ``dir_path`` or a note that the directory does
    not exist.

    Args:
        dir_path: Directory that was searched.
        filename: Name of the file that was not found there.
        error_messages: One-line error summaries; mutated in place.
        attachments: Multi-line diagnostic details; mutated in place.
    """
    error_messages.append(f"No such file: {os.path.join(dir_path, filename)}")
    if not os.path.isdir(dir_path):
        attachments.append(f' Directory does not exist: "{dir_path}"')
        return
    # This helper only builds diagnostic text. A directory that exists but
    # cannot be listed (permissions, or removed after the isdir() check) must
    # not raise an OSError that replaces the not-found report being assembled.
    try:
        nodes = sorted(os.listdir(dir_path))
    except OSError as exc:
        attachments.append(f' listdir("{dir_path}") failed: {exc}')
        return
    attachments.append(f' listdir("{dir_path}"):')
    attachments.extend(f" {node}" for node in nodes)
|
|
|
|
class _FindBitcodeLib:
    """Search strategies for one supported bitcode library.

    Each ``try_*`` method returns the path of the library file when found and
    ``None`` otherwise. On failure it may append diagnostics to
    ``error_messages`` / ``attachments``, which ``raise_not_found_error``
    folds into the final exception.
    """

    def __init__(self, name: str) -> None:
        """Load the search configuration for ``name``.

        Raises:
            ValueError: If ``name`` is not a supported bitcode library.
        """
        if name not in _SUPPORTED_BITCODE_LIBS_INFO:
            raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}")
        self.name: str = name
        self.config: _BitcodeLibInfo = _SUPPORTED_BITCODE_LIBS_INFO[name]
        self.filename: str = self.config["filename"]
        self.rel_path: str = self.config["rel_path"]
        self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
        # One-line failure summaries; joined with ", " in the final error.
        self.error_messages: list[str] = []
        # Multi-line details (e.g. directory listings) appended after the summary.
        self.attachments: list[str] = []

    def try_site_packages(self) -> str | None:
        """Look for the file under each configured site-packages sub-directory."""
        for rel_dir in self.site_packages_dirs:
            # The config stores "/"-separated paths; the helper takes path components.
            sub_dir = tuple(rel_dir.split("/"))
            for abs_dir in find_sub_dirs_all_sitepackages(sub_dir):
                file_path = os.path.join(abs_dir, self.filename)
                if os.path.isfile(file_path):
                    return file_path
        # Record what was searched so a failure report covers this strategy too,
        # not only the CUDA_HOME one.
        if self.site_packages_dirs:
            searched = ", ".join(f"{rel_dir}/{self.filename}" for rel_dir in self.site_packages_dirs)
            self.error_messages.append(f"Not found in site-packages: {searched}")
        return None

    def try_with_conda_prefix(self) -> str | None:
        """Look for the file below ``$CONDA_PREFIX`` (``$CONDA_PREFIX/Library`` on Windows)."""
        conda_prefix = os.environ.get("CONDA_PREFIX")
        if not conda_prefix:
            # No active conda environment: nothing was searched, nothing to report.
            return None

        anchor = os.path.join(conda_prefix, "Library") if IS_WINDOWS else conda_prefix
        dir_path = os.path.join(anchor, self.rel_path)
        file_path = os.path.join(dir_path, self.filename)
        if os.path.isfile(file_path):
            return file_path

        # Report the miss the same way try_with_cuda_home does.
        _no_such_file_in_dir(dir_path, self.filename, self.error_messages, self.attachments)
        return None

    def try_with_cuda_home(self) -> str | None:
        """Look for the file below the directory from ``get_cuda_home_or_path()``."""
        cuda_home = get_cuda_home_or_path()
        if cuda_home is None:
            self.error_messages.append("CUDA_HOME/CUDA_PATH not set")
            return None

        file_path = os.path.join(cuda_home, self.rel_path, self.filename)
        if os.path.isfile(file_path):
            return file_path

        _no_such_file_in_dir(
            os.path.join(cuda_home, self.rel_path),
            self.filename,
            self.error_messages,
            self.attachments,
        )
        return None

    def raise_not_found_error(self) -> NoReturn:
        """Raise ``BitcodeLibNotFoundError`` summarizing every failed strategy.

        Raises:
            BitcodeLibNotFoundError: Always.
        """
        err = ", ".join(self.error_messages) if self.error_messages else "No search paths available"
        msg = f'Failure finding "{self.filename}": {err}'
        # Only add the details section when there is something to show, so the
        # message does not end with a stray newline.
        if self.attachments:
            msg += "\n" + "\n".join(self.attachments)
        raise BitcodeLibNotFoundError(msg)
|
|
|
|
def locate_bitcode_lib(name: str) -> LocatedBitcodeLib:
    """Locate a bitcode library by name.

    Strategies are tried in this order, and the first one that returns a path wins:
    site-packages, the active conda environment, then CUDA_HOME/CUDA_PATH.

    Raises:
        ValueError: If ``name`` is not a supported bitcode library.
        BitcodeLibNotFoundError: If the bitcode library cannot be found.
    """
    finder = _FindBitcodeLib(name)

    # (search strategy, label recorded in LocatedBitcodeLib.found_via), in priority order.
    strategies = (
        (finder.try_site_packages, "site-packages"),
        (finder.try_with_conda_prefix, "conda"),
        (finder.try_with_cuda_home, "CUDA_HOME"),
    )
    for attempt, origin in strategies:
        hit = attempt()
        if hit is None:
            continue
        return LocatedBitcodeLib(
            name=name,
            abs_path=hit,
            filename=finder.filename,
            found_via=origin,
        )

    finder.raise_not_found_error()
|
|
|
|
@functools.cache
def find_bitcode_lib(name: str) -> str:
    """Return the path to the bitcode library called ``name``.

    This is a thin wrapper over ``locate_bitcode_lib``. A successful result is
    memoized per ``name`` for the life of the process. A failed lookup raises,
    and ``functools.cache`` does not store exceptions, so a later call searches again.

    Raises:
        ValueError: If ``name`` is not a supported bitcode library.
        BitcodeLibNotFoundError: If the bitcode library cannot be found.
    """
    located = locate_bitcode_lib(name)
    return located.abs_path
|
|