|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import glob
|
|
|
import torch
|
|
|
import torch.utils.cpp_extension
|
|
|
import importlib
|
|
|
import hashlib
|
|
|
import shutil
|
|
|
from pathlib import Path
|
|
|
|
|
|
from torch.utils.file_baton import FileBaton
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Console output level used by get_plugin(). Allowed values:
#   'none'  - get_plugin() prints no status messages,
#   'brief' - one "Setting up ... Done." line per plugin,
#   'full'  - separate start/done lines, and the extension build runs with verbose=True.
verbosity = "brief"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _find_compiler_bindir():
    """Locate an MSVC compiler ``bin`` directory in a standard install location.

    The glob patterns are tried in order. The first pattern that matches anything
    wins, and the lexicographically last of its matches is returned, which picks
    the highest year / toolset directory for that edition.

    The original Visual Studio 2017/2019 locations are tried first, so results on
    machines that already matched them are unchanged. The Enterprise edition and
    the 64-bit "Program Files" locations used by Visual Studio 2022+ are tried
    only after those.

    Returns:
        The matched directory as a string, or ``None`` if no pattern matched.
    """
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
        # Locations not covered above: the Enterprise edition, and Visual Studio
        # 2022+, which is a 64-bit application installed under "Program Files".
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files/Microsoft Visual Studio/*/Enterprise/VC/Tools/MSVC/*/bin/Hostx64/x64',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if matches:
            return matches[-1]
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-process cache of modules already loaded by get_plugin(), keyed by module name.
# Only successfully imported modules are stored here.
_cached_plugins = {}
|
|
|
|
|
|
def _ensure_msvc_on_path():
    """On Windows, make sure the MSVC compiler ``cl.exe`` can be found on PATH.

    This does nothing on other platforms or when ``cl.exe`` is already reachable.
    Otherwise it appends the directory returned by ``_find_compiler_bindir()`` to
    ``os.environ['PATH']``, which is a process-wide side effect.

    Raises:
        RuntimeError: if ``cl.exe`` is not on PATH and no install directory was found.
    """
    if os.name != 'nt' or shutil.which('cl.exe') is not None:
        return
    compiler_bindir = _find_compiler_bindir()
    if compiler_bindir is None:
        raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
    os.environ['PATH'] += os.pathsep + compiler_bindir


def _stage_sources_in_digest_dir(source_files, build_dir):
    """Copy ``source_files`` into ``<build_dir>/<md5 digest>`` and return that path.

    The digest covers each file's base name and contents, so an unchanged file
    set always maps to the same directory and is copied only once. Files are
    staged in a private temporary directory and then published with an atomic
    ``os.replace``. As a result, neither a concurrent process nor a later run
    (e.g. after a crash mid-copy) can observe a half-populated digest directory.

    Args:
        source_files: paths of the files to copy. Files are copied by base name.
        build_dir: directory under which the digest directory lives; must exist.

    Returns:
        Path of the digest directory.
    """
    import tempfile  # local import: only this caching path needs it

    hash_md5 = hashlib.md5()
    for src in source_files:
        with open(src, 'rb') as f:
            data = f.read()
        # Hash the name (NUL-terminated) and a length prefix as well as the bytes.
        # Otherwise renaming a file, adding an empty file, or shifting bytes between
        # files would reuse a directory holding the wrong set of files.
        hash_md5.update(os.fsencode(os.path.basename(src)) + b'\0')
        hash_md5.update(len(data).to_bytes(8, 'little'))
        hash_md5.update(data)
    digest_dir = os.path.join(build_dir, hash_md5.hexdigest())

    if not os.path.isdir(digest_dir):
        tmp_dir = tempfile.mkdtemp(prefix='srctmp-', dir=build_dir)
        try:
            for src in source_files:
                shutil.copyfile(src, os.path.join(tmp_dir, os.path.basename(src)))
            try:
                os.replace(tmp_dir, digest_dir)
            except OSError:
                # Another process published this digest first; reuse theirs.
                if not os.path.isdir(digest_dir):
                    raise
        finally:
            # No-op after a successful rename; removes leftovers on failure or a lost race.
            shutil.rmtree(tmp_dir, ignore_errors=True)
    return digest_dir


def get_plugin(module_name, sources, **build_kwargs):
    """Build (if needed) and import a PyTorch C++/CUDA extension module.

    Results are cached per process in ``_cached_plugins``, so repeated calls with
    the same ``module_name`` return the already-imported module. Progress is
    printed according to the module-level ``verbosity`` setting.

    All sources may live in a single directory while the ``TORCH_EXTENSIONS_DIR``
    environment variable is set. In that case every regular file in that
    directory is first copied into a digest-named subdirectory of the torch build
    directory, so an unchanged file set reuses the same paths across runs. The
    extension is then built from those copies.

    Args:
        module_name: name of the extension module to build and import.
        sources: source file paths passed to ``torch.utils.cpp_extension.load``.
        **build_kwargs: extra keyword arguments forwarded to
            ``torch.utils.cpp_extension.load``.

    Returns:
        The imported extension module.

    Raises:
        RuntimeError: on Windows, when no MSVC compiler could be located.
        Any exception raised while building or importing is re-raised after
        printing 'Failed!' (in 'brief' mode).
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already loaded in this process?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try:
        _ensure_msvc_on_path()
        verbose_build = (verbosity == 'full')

        source_dirs = {os.path.dirname(source) for source in sources}
        if len(source_dirs) == 1 and 'TORCH_EXTENSIONS_DIR' in os.environ:
            (source_dir,) = source_dirs
            # Copy every file in the directory (not just `sources`), presumably so
            # that headers included by the sources travel along with them.
            all_source_files = sorted(x for x in Path(source_dir).iterdir() if x.is_file())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build)
            digest_dir = _stage_sources_in_digest_dir(all_source_files, build_dir)
            digest_sources = [os.path.join(digest_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except BaseException:
        # Finish the pending "Setting up..." line, then propagate unchanged.
        if verbosity == 'brief':
            print('Failed!')
        raise

    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')

    _cached_plugins[module_name] = module
    return module
|
|
|
|
|
|
|
|
|
|