| | |
| | |
| | |
| | import argparse |
| | import collections |
| | from datetime import datetime |
| | import os |
| | import platform |
| | import re |
| | import requests |
| | import subprocess |
| | import threading |
| | import sys |
| |
|
# Resolved installation target: wheel version string, TPU runtime version,
# CPython major+minor digits (e.g. '37'), and CUDA version (None without GPU).
VersionConfig = collections.namedtuple(
    'VersionConfig', 'wheels tpu py_version cuda_version')

DEFAULT_CUDA_VERSION = '10.2'

# Earliest dated nightlies that may be requested (CPU/TPU and CUDA builds).
OLDEST_VERSION = datetime(2020, 3, 18)
OLDEST_GPU_VERSION = datetime(2020, 7, 7)

DIST_BUCKET = 'gs://tpu-pytorch/wheels'

# Every wheel shares the same Python/ABI/platform tag suffix.
_WHEEL_SUFFIX = '-cp{py_version}-cp{py_version}m-linux_x86_64.whl'
TORCH_WHEEL_TMPL = 'torch-{whl_version}' + _WHEEL_SUFFIX
TORCH_XLA_WHEEL_TMPL = 'torch_xla-{whl_version}' + _WHEEL_SUFFIX
TORCHVISION_WHEEL_TMPL = 'torchvision-{whl_version}' + _WHEEL_SUFFIX
| |
|
| |
|
def is_gpu_runtime():
    """Return True when the COLAB_GPU environment variable is set to 1.

    Environment variable values are always strings, so the value is parsed
    as an integer before comparing; comparing the raw string against the
    int ``1`` can never succeed. Unset or non-numeric values mean "no GPU".
    """
    value = os.environ.get('COLAB_GPU', '0')
    try:
        return int(value) == 1
    except ValueError:
        return False
| |
|
| |
|
def is_tpu_runtime():
    """Report whether the TPU_NAME environment variable is present.

    Any value counts, including the empty string.
    """
    return os.environ.get('TPU_NAME') is not None
| |
|
| |
|
def update_tpu_runtime(tpu_name, version):
    """Switch the TPU's runtime to ``version.tpu`` using ``cloud_tpu_client``.

    If ``cloud_tpu_client`` is not importable, the ``cloud-tpu-client``
    package is pip-installed into the running interpreter first.

    Args:
      tpu_name: value forwarded to ``cloud_tpu_client.Client``. run_setup
        passes ``args.tpu``, which is None when --tpu is omitted (presumably
        the client then locates the TPU on its own -- confirm).
      version: a VersionConfig; only its ``tpu`` field is used.

    Raises:
      subprocess.CalledProcessError: if installing cloud-tpu-client fails.
    """
    print(f'Updating TPU runtime to {version.tpu} ...')

    try:
        import cloud_tpu_client
    except ImportError:
        import importlib
        # Fail loudly here rather than with a confusing ImportError below.
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
        # Packages installed after interpreter start-up may be invisible to
        # the import system until its finder caches are refreshed.
        importlib.invalidate_caches()
        import cloud_tpu_client

    client = cloud_tpu_client.Client(tpu_name)
    client.configure_tpu_version(version.tpu)
    print('Done updating TPU runtime')
| |
|
| |
|
def get_py_version():
    """Return the running CPython's major and minor digits joined, e.g. '37'."""
    major, minor, _patch = platform.python_version_tuple()
    return f'{major}{minor}'
| |
|
| |
|
def get_cuda_version():
    """Return DEFAULT_CUDA_VERSION on a GPU runtime, otherwise None.

    The CUDA version is not detected; the default is always assumed.
    """
    return DEFAULT_CUDA_VERSION if is_gpu_runtime() else None
| |
|
| |
|
def get_version(version):
    """Translate a ``--version`` string into a VersionConfig.

    Accepted forms:
      * ``'nightly'``: wheels ``'nightly'``, TPU runtime ``'pytorch-nightly'``.
      * ``'YYYYMMDD'``: a dated nightly; wheels ``'nightly+YYYYMMDD'`` and TPU
        runtime ``'pytorch-devYYYYMMDD'``.
      * a dotted release such as ``'1.5'``: wheels ``'1.5'`` and TPU runtime
        ``'pytorch-1.5'``.

    Args:
      version: the requested version string.

    Returns:
      A VersionConfig for the running Python and (if any) CUDA version.

    Raises:
      ValueError: if a dated nightly predates OLDEST_VERSION (or
        OLDEST_GPU_VERSION when a CUDA version is in use), or if the string
        matches none of the accepted forms.
    """
    cuda_version = get_cuda_version()
    py_version = get_py_version()

    if version == 'nightly':
        return VersionConfig(
            'nightly', 'pytorch-nightly', py_version, cuda_version)

    try:
        version_date = datetime.strptime(version, '%Y%m%d')
    except ValueError:
        version_date = None

    if version_date is not None:
        # Report limits in the same YYYYMMDD form the user has to pass.
        if cuda_version and version_date < OLDEST_GPU_VERSION:
            raise ValueError(
                'Oldest nightly version built with CUDA available is '
                f'{OLDEST_GPU_VERSION:%Y%m%d}')
        if not cuda_version and version_date < OLDEST_VERSION:
            raise ValueError(
                f'Oldest nightly version available is {OLDEST_VERSION:%Y%m%d}')
        return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}',
                             py_version, cuda_version)

    # Raw string avoids the invalid "\d" escape warning; fullmatch (unlike
    # match with "$") also rejects a trailing newline.
    if not re.fullmatch(r'(\d+\.)+\d+', version):
        raise ValueError(f'{version} is an invalid torch_xla version pattern')
    return VersionConfig(
        version, f'pytorch-{version}', py_version, cuda_version)
| |
|
| |
|
def _wheel_file_name(template, version):
    """Render a wheel-name template for ``version``, fixing the ABI tag on 3.8+."""
    py_version = version.py_version
    name = template.format(whl_version=version.wheels, py_version=py_version)
    # The templates carry the 'm' (pymalloc) ABI flag, which CPython dropped
    # from ABI tags in 3.8, so wheels for 3.8+ are tagged cpXY-cpXY.
    if py_version.isdigit():
        major, minor = int(py_version[0]), int(py_version[1:] or '0')
        if (major, minor) >= (3, 8):
            name = name.replace(f'-cp{py_version}m-', f'-cp{py_version}-')
    return name


def _run_command(cmd):
    """Run ``cmd``; print a warning and return False if it exits non-zero."""
    returncode = subprocess.call(cmd)
    if returncode != 0:
        print(f'Warning: `{" ".join(cmd)}` exited with status {returncode}')
    return returncode == 0


def install_vm(version, apt_packages, is_root=False):
    """Install the torch, torch_xla and torchvision wheels plus apt packages.

    Wheels are copied with gsutil from DIST_BUCKET (or its ``cuda/<XY>``
    sub-directory when ``version.cuda_version`` is set) into the current
    directory, then pip-installed in place of the existing torch and
    torchvision. Failures are reported on stdout rather than raised. If any
    download fails, the pip steps are skipped so the existing torch install
    is left intact.

    Args:
      version: a VersionConfig selecting wheel version, Python tag and
        optional CUDA flavour.
      apt_packages: package names passed to ``apt-get install -y``.
      is_root: when False, the apt command is prefixed with ``sudo``.
    """
    dist_bucket = DIST_BUCKET
    if version.cuda_version:
        cuda_dir = version.cuda_version.replace('.', '')
        # gs:// locations are URLs, so join with '/' rather than os.path.
        dist_bucket = f'{DIST_BUCKET}/cuda/{cuda_dir}'

    wheels = [
        _wheel_file_name(template, version)
        for template in (TORCH_WHEEL_TMPL, TORCH_XLA_WHEEL_TMPL,
                         TORCHVISION_WHEEL_TMPL)
    ]

    # Download everything before uninstalling anything. The list (not a
    # generator) makes every download run, even after a failure.
    downloads_ok = all([
        _run_command(['gsutil', 'cp', f'{dist_bucket}/{whl}', '.'])
        for whl in wheels
    ])

    if downloads_ok:
        _run_command([sys.executable, '-m', 'pip', 'uninstall', '-y', 'torch',
                      'torchvision'])
        for whl in wheels:
            _run_command([sys.executable, '-m', 'pip', 'install', whl])
    else:
        print('Skipping wheel installation because a download failed; '
              'the existing torch installation was left untouched.')

    apt_cmd = ['apt-get', 'install', '-y', *apt_packages]
    if not is_root:
        apt_cmd.insert(0, 'sudo')
    _run_command(apt_cmd)
| |
|
| |
|
def run_setup(args):
    """Install the requested wheels and, when applicable, update the TPU runtime.

    The TPU runtime update runs on a background thread, concurrently with the
    local install, and is joined before returning.

    Args:
      args: parsed command-line namespace with ``version``, ``apt_packages``
        and ``tpu`` attributes.
    """
    version = get_version(args.version)
    print('Updating... This may take around 2 minutes.')

    # Decide once: re-checking the environment at join time could disagree
    # with the start decision and leave the thread unbound. An explicit --tpu
    # also triggers the update, since is_tpu_runtime() only looks at TPU_NAME.
    update = None
    if args.tpu or is_tpu_runtime():
        update = threading.Thread(
            target=update_tpu_runtime, args=(args.tpu, version))
        update.start()

    install_vm(version, args.apt_packages, is_root=not args.tpu)

    if update is not None:
        update.join()
| |
|
| |
|
if __name__ == '__main__':
    # (flag, add_argument keyword options) for every command-line option.
    cli_options = (
        ('--version', {
            'type': str,
            'default': '20200515',
            'help': 'Versions to install (nightly, release version, or YYYYMMDD).',
        }),
        ('--apt-packages', {
            'nargs': '+',
            'default': ['libomp5'],
            'help': 'List of apt packages to install',
        }),
        ('--tpu', {
            'type': str,
            'help': '[GCP] Name of the TPU (same zone, project as VM running script)',
        }),
    )
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    run_setup(args)
| |
|