| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import glob |
| | import torch |
| | import torch.utils.cpp_extension |
| | import importlib |
| | import hashlib |
| | import shutil |
| | from pathlib import Path |
| |
|
| | from torch.utils.file_baton import FileBaton |
| |
|
| | |
| | |
| |
|
# Console verbosity used by get_plugin(); must be one of 'none', 'brief' or 'full'.
verbosity = 'brief'
| |
|
| | |
| | |
| |
|
def _find_compiler_bindir():
    """Locate a directory that holds the MSVC compiler binaries.

    The glob patterns below are tried in priority order. For the first
    pattern that matches at least one path, the lexicographically last match
    is returned.

    Returns:
        The matched directory path as a string, or None when no pattern
        matches anything.
    """
    # 'Program Files*' covers both 'Program Files (x86)' and 'Program Files';
    # 64-bit Visual Studio releases install under the latter.
    patterns = [
        'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if matches:
            return matches[-1]
    return None
| |
|
| | |
| | |
| |
|
# Modules already set up by get_plugin() in this process, keyed by module name.
_cached_plugins = {}
| |
|
def get_plugin(module_name, sources, **build_kwargs):
    """Build (if necessary), import and return a PyTorch extension module.

    Args:
        module_name: Name of the extension module; also the key under which
            the loaded module is cached for the lifetime of the process.
        sources: Paths of the source files handed to the extension builder.
        **build_kwargs: Forwarded unchanged to
            ``torch.utils.cpp_extension.load()``.

    Returns:
        The imported module. A repeated call with the same ``module_name``
        returns the cached module without rebuilding.

    Raises:
        RuntimeError: On Windows, when ``cl.exe`` is not on PATH and
            ``_find_compiler_bindir()`` finds no compiler directory.
        Exception: Anything raised while hashing, copying, building or
            importing is re-raised after the failure has been reported.
    """
    import uuid  # Local import: only needed for the staging directory name.

    assert verbosity in ['none', 'brief', 'full']

    # Already set up in this process?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Report progress.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try:
        # On Windows, make sure cl.exe can be found; extend PATH if it cannot.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        verbose_build = (verbosity == 'full')

        # Digest-keyed source caching: when every source lives in a single
        # directory and TORCH_EXTENSIONS_DIR is set, all files of that
        # directory are copied into a sub-directory of the build directory
        # named after the MD5 of their contents, and the build uses the copies.
        source_dirs = {os.path.dirname(source) for source in sources}
        if len(source_dirs) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            (source_dir,) = source_dirs
            all_source_files = sorted(x for x in Path(source_dir).iterdir() if x.is_file())

            # Combined digest over the contents of every file, in sorted order.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build)
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                # Copy into a uniquely named staging directory and publish it
                # with a single rename, so that a concurrent process never
                # sees a digest directory whose files are only partly copied.
                staging_dir = os.path.join(build_dir, f'srctmp-{uuid.uuid4().hex}')
                os.makedirs(staging_dir)
                try:
                    for src in all_source_files:
                        shutil.copyfile(src, os.path.join(staging_dir, os.path.basename(src)))
                    os.replace(staging_dir, digest_build_dir)
                except OSError:
                    # Acceptable only if another process published the same
                    # digest directory in the meantime.
                    if not os.path.isdir(digest_build_dir):
                        raise
                finally:
                    # No-op after a successful rename; otherwise removes leftovers.
                    shutil.rmtree(staging_dir, ignore_errors=True)

            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except BaseException:
        # Finish the 'brief' status line before propagating any failure,
        # including KeyboardInterrupt.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Report success and remember the module for later calls.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
| |
|
| | |
| |
|