|
|
import collections |
|
|
import os |
|
|
import shutil |
|
|
import subprocess |
|
|
|
|
|
try: |
|
|
|
|
|
import conda.cli.python_api |
|
|
from conda.cli.python_api import Commands as conda_commands |
|
|
except ImportError: |
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
WORKING_ROOT = "/tmp/pytorch_blas_compare_environments" |
|
|
MKL_2020_3 = "mkl_2020_3" |
|
|
MKL_2020_0 = "mkl_2020_0" |
|
|
OPEN_BLAS = "open_blas" |
|
|
EIGEN = "eigen" |
|
|
|
|
|
|
|
|
GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0") |
|
|
BASE_PKG_DEPS = ( |
|
|
"cffi", |
|
|
"cmake", |
|
|
"hypothesis", |
|
|
"ninja", |
|
|
"numpy", |
|
|
"pyyaml", |
|
|
"setuptools", |
|
|
"typing_extensions", |
|
|
) |
|
|
|
|
|
|
|
|
SubEnvSpec = collections.namedtuple( |
|
|
"SubEnvSpec", ( |
|
|
"generic_installs", |
|
|
"special_installs", |
|
|
"environment_variables", |
|
|
|
|
|
|
|
|
"expected_blas_symbols", |
|
|
"expected_mkl_version", |
|
|
)) |
|
|
|
|
|
|
|
|
SUB_ENVS = { |
|
|
MKL_2020_3: SubEnvSpec( |
|
|
generic_installs=(), |
|
|
special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")), |
|
|
environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS, |
|
|
expected_blas_symbols=("mkl_blas_sgemm",), |
|
|
expected_mkl_version="2020.0.3", |
|
|
), |
|
|
|
|
|
MKL_2020_0: SubEnvSpec( |
|
|
generic_installs=(), |
|
|
special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")), |
|
|
environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS, |
|
|
expected_blas_symbols=("mkl_blas_sgemm",), |
|
|
expected_mkl_version="2020.0.0", |
|
|
), |
|
|
|
|
|
OPEN_BLAS: SubEnvSpec( |
|
|
generic_installs=("openblas",), |
|
|
special_installs=(), |
|
|
environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS, |
|
|
expected_blas_symbols=("exec_blas",), |
|
|
expected_mkl_version=None, |
|
|
), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
def conda_run(*args): |
|
|
"""Convenience method.""" |
|
|
stdout, stderr, retcode = conda.cli.python_api.run_command(*args) |
|
|
if retcode: |
|
|
raise OSError(f"conda error: {str(args)} retcode: {retcode}\n{stderr}") |
|
|
|
|
|
return stdout |
|
|
|
|
|
|
|
|
def main(): |
|
|
if os.path.exists(WORKING_ROOT): |
|
|
print("Cleaning: removing old working root.") |
|
|
shutil.rmtree(WORKING_ROOT) |
|
|
os.makedirs(WORKING_ROOT) |
|
|
|
|
|
git_root = subprocess.check_output( |
|
|
"git rev-parse --show-toplevel", |
|
|
shell=True, |
|
|
cwd=os.path.dirname(os.path.realpath(__file__)) |
|
|
).decode("utf-8").strip() |
|
|
|
|
|
for env_name, env_spec in SUB_ENVS.items(): |
|
|
env_path = os.path.join(WORKING_ROOT, env_name) |
|
|
print(f"Creating env: {env_name}: ({env_path})") |
|
|
conda_run( |
|
|
conda_commands.CREATE, |
|
|
"--no-default-packages", |
|
|
"--prefix", env_path, |
|
|
"python=3", |
|
|
) |
|
|
|
|
|
print("Testing that env can be activated:") |
|
|
base_source = subprocess.run( |
|
|
f"source activate {env_path}", |
|
|
shell=True, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.PIPE, |
|
|
) |
|
|
if base_source.returncode: |
|
|
raise OSError( |
|
|
"Failed to source base environment:\n" |
|
|
f" stdout: {base_source.stdout.decode('utf-8')}\n" |
|
|
f" stderr: {base_source.stderr.decode('utf-8')}" |
|
|
) |
|
|
|
|
|
print("Installing packages:") |
|
|
conda_run( |
|
|
conda_commands.INSTALL, |
|
|
"--prefix", env_path, |
|
|
*(BASE_PKG_DEPS + env_spec.generic_installs) |
|
|
) |
|
|
|
|
|
if env_spec.special_installs: |
|
|
channel, channel_deps = env_spec.special_installs |
|
|
print(f"Installing packages from channel: {channel}") |
|
|
conda_run( |
|
|
conda_commands.INSTALL, |
|
|
"--prefix", env_path, |
|
|
"-c", channel, *channel_deps |
|
|
) |
|
|
|
|
|
if env_spec.environment_variables: |
|
|
print("Setting environment variables.") |
|
|
|
|
|
|
|
|
env_set = subprocess.run( |
|
|
f"source activate {env_path} && " |
|
|
f"conda env config vars set {' '.join(env_spec.environment_variables)}", |
|
|
shell=True, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.PIPE, |
|
|
) |
|
|
if env_set.returncode: |
|
|
raise OSError( |
|
|
"Failed to set environment variables:\n" |
|
|
f" stdout: {env_set.stdout.decode('utf-8')}\n" |
|
|
f" stderr: {env_set.stderr.decode('utf-8')}" |
|
|
) |
|
|
|
|
|
|
|
|
actual_env_vars = subprocess.run( |
|
|
f"source activate {env_path} && env", |
|
|
shell=True, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.PIPE, |
|
|
).stdout.decode("utf-8").strip().splitlines() |
|
|
for e in env_spec.environment_variables: |
|
|
assert e in actual_env_vars, f"{e} not in envs" |
|
|
|
|
|
print(f"Building PyTorch for env: `{env_name}`") |
|
|
|
|
|
|
|
|
build_run = subprocess.run( |
|
|
f"source activate {env_path} && " |
|
|
f"cd {git_root} && " |
|
|
"python setup.py install --cmake", |
|
|
shell=True, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.PIPE, |
|
|
) |
|
|
|
|
|
print("Checking configuration:") |
|
|
check_run = subprocess.run( |
|
|
|
|
|
f"source activate {env_path} && " |
|
|
"python -c \"" |
|
|
"import torch;" |
|
|
"from torch.utils.benchmark import Timer;" |
|
|
"print(torch.__config__.show());" |
|
|
"setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';" |
|
|
"counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);" |
|
|
"stats = counts.as_standardized().stats(inclusive=True);" |
|
|
"print(stats.filter(lambda l: 'blas' in l.lower()))\"", |
|
|
shell=True, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.PIPE, |
|
|
) |
|
|
if check_run.returncode: |
|
|
raise OSError( |
|
|
"Failed to set environment variables:\n" |
|
|
f" stdout: {check_run.stdout.decode('utf-8')}\n" |
|
|
f" stderr: {check_run.stderr.decode('utf-8')}" |
|
|
) |
|
|
check_run_stdout = check_run.stdout.decode('utf-8') |
|
|
print(check_run_stdout) |
|
|
|
|
|
for e in env_spec.environment_variables: |
|
|
if "BLAS" in e: |
|
|
assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}" |
|
|
|
|
|
for s in env_spec.expected_blas_symbols: |
|
|
assert s in check_run_stdout |
|
|
|
|
|
if env_spec.expected_mkl_version is not None: |
|
|
assert f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}" in check_run_stdout |
|
|
|
|
|
print(f"Build complete: {env_name}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|