import os
import platform
import re
import subprocess
import sys
from pathlib import Path
def get_cuda_ver_from_dir(cuda_home):
    """Infer the CuPy wheel suffix from the nvrtc-builtins library in a directory.

    The directory is scanned non-recursively for files whose name contains
    ``nvrtc-builtins``. The CUDA version is read from the version part of the
    file name, either dotted (``libnvrtc-builtins.so.11.8``) or compact
    (``nvrtc-builtins64_118.dll``).

    Args:
        cuda_home: Path of the directory to scan, or None.

    Returns:
        One of ``'102'``, ``'110'``, ``'111'``, ``'11x'``, ``'12x'``, or None
        when the directory is missing/unreadable, holds no nvrtc-builtins
        file, or the version found is not one of the supported ones.
    """
    if cuda_home is None:
        # os.listdir(None) would silently scan the current working directory.
        return None
    try:
        entries = os.listdir(cuda_home)
    except OSError:
        # Missing directory, not a directory, permission denied, ...
        return None

    # Sorted so the answer does not depend on os.listdir's arbitrary order.
    for lib_file in sorted(entries):
        _, marker, tail = lib_file.partition("nvrtc-builtins")
        if not marker:
            continue

        # Digits are only accepted right after a "." or "_" separator, so a
        # build hash embedded in the name (e.g. "-a1102f11") is never mistaken
        # for a version number.
        dotted = re.search(r"\.(\d+)(?:\.(\d+))?", tail)
        if dotted:
            major = int(dotted.group(1))
            minor = int(dotted.group(2)) if dotted.group(2) is not None else None
        else:
            compact = re.search(r"_(\d{3,})", tail)
            if not compact:
                continue  # unversioned name such as "libnvrtc-builtins.so"
            digits = compact.group(1)
            major, minor = int(digits[:2]), int(digits[2:])

        if (major, minor) in ((10, 2), (11, 0), (11, 1)):
            return f"{major}{minor}"
        if major in (11, 12):
            return f"{major}x"
        # Unsupported version: keep looking at the remaining candidates.

    return None
# Extra interpreter flag placed between the Python executable and "-m pip" in
# the pip commands below. "-s" is CPython's "do not add the user site-packages
# directory to sys.path" switch; it is only used when the interpreter path
# contains "python_embeded".
# NOTE(review): the "python_embeded" spelling presumably mirrors an on-disk
# folder name — confirm before "correcting" it.
if "python_embeded" in sys.executable:
    s_param = '-s'
else:
    s_param = ''
def get_cuda_home_path():
    """Locate the directory to use as ``CUDA_HOME``.

    Lookup order:
      1. The ``CUDA_HOME`` environment variable, when set to a non-empty value.
      2. The ``lib`` directory next to the installed ``torch`` package, but
         only when it contains a file whose name includes ``nvrtc-builtins``.

    Returns:
        The directory path as a string, or None when neither source is
        available (including when torch is not installed).
    """
    env_cuda_home = os.environ.get("CUDA_HOME")
    if env_cuda_home:  # an empty value is treated the same as "not set"
        return env_cuda_home

    try:
        # Imported lazily: torch is only needed for this fallback and is slow
        # to import.
        import torch
    except ImportError:
        return None

    torch_file = getattr(torch, "__file__", None)
    if not torch_file:
        return None

    torch_lib_dir = (Path(torch_file).parent / "lib").resolve()
    if not torch_lib_dir.is_dir():
        return None

    has_nvrtc = any(
        "nvrtc-builtins" in lib_file for lib_file in os.listdir(torch_lib_dir)
    )
    return str(torch_lib_dir) if has_nvrtc else None
def install_cupy():
    """Make sure an importable CuPy is installed, installing it via pip if not.

    Exports the detected CUDA directory (if any) as ``CUDA_HOME`` and
    ``CUDA_PATH``, then tries ``import cupy``. If the import fails, every known
    CuPy distribution is uninstalled and the wheel matching the detected CUDA
    version is installed (``cupy-wheel`` when the version is unknown).

    Returns:
        None. Progress and pip failures are reported on stdout.
    """

    def _pip(*args):
        # Run "<python> [-s] -m pip <args>" without a shell so interpreter
        # paths containing spaces or shell metacharacters are passed intact.
        command = [sys.executable]
        if s_param:
            command.append(s_param)
        command += ["-m", "pip", *args]
        return subprocess.run(command, check=False).returncode

    cuda_home = get_cuda_home_path()
    if cuda_home is not None:
        os.environ["CUDA_HOME"] = cuda_home
        os.environ["CUDA_PATH"] = cuda_home

    try:
        import cupy  # noqa: F401 -- imported only to check that it works
    except Exception:
        # Any import-time failure (missing package, wrong CUDA build, broken
        # shared libraries) means a reinstall is needed. KeyboardInterrupt and
        # SystemExit are deliberately allowed to propagate.
        pass
    else:
        print("CuPy is already installed.")
        return

    print("Uninstall cupy if existed...")
    _pip(
        "uninstall", "-y",
        "cupy-wheel", "cupy-cuda102", "cupy-cuda110", "cupy-cuda111",
        "cupy-cuda11x", "cupy-cuda12x",
    )

    print("Installing cupy...")
    cuda_ver = None
    if cuda_home is not None:
        # Skipped when no CUDA directory was found, so the version probe is
        # never handed None.
        try:
            cuda_ver = get_cuda_ver_from_dir(cuda_home)
        except OSError:
            # e.g. CUDA_HOME points at a directory that does not exist.
            cuda_ver = None
    cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel"
    if _pip("install", cupy_package) != 0:
        print(f"Failed to install {cupy_package}.")
# Install every requirement listed in requirements-no-cupy.txt, one pip call
# per entry so that a single failing package does not block the others.
with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r', encoding="utf-8") as f:
    for raw_line in f:
        # Drop full-line and trailing "# ..." comments (same rule pip applies
        # to requirements files), then skip lines that end up empty.
        package = re.sub(r"(^|\s+)#.*$", "", raw_line).strip()
        if not package:
            continue
        print(f"Installing {package}...")
        pip_command = [sys.executable]
        if s_param:
            pip_command.append(s_param)
        pip_command += ["-m", "pip", "install"]
        if package.startswith("-"):
            # Option line such as "--extra-index-url URL": one argument per token.
            pip_command += package.split()
        else:
            # Passed as a single argument and without a shell, so specifiers
            # like "pkg>=1.0" are not turned into output redirections.
            pip_command.append(package)
        if subprocess.run(pip_command, check=False).returncode != 0:
            print(f"Failed to install {package}.")

print("Checking cupy...")
install_cupy()