Update pytorch-xla-env-setup.py
Browse files- pytorch-xla-env-setup.py +161 -0
pytorch-xla-env-setup.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Sample usage:
|
| 3 |
+
#   python pytorch-xla-env-setup.py --version 1.5 --apt-packages libomp5
|
| 4 |
+
import argparse
|
| 5 |
+
import collections
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import os
|
| 8 |
+
import platform
|
| 9 |
+
import re
|
| 10 |
+
import requests
|
| 11 |
+
import subprocess
|
| 12 |
+
import threading
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
# Resolved install target: wheel version tag, TPU runtime name, CPython
# version digits (e.g. '36') and CUDA version (None when no GPU is used).
VersionConfig = collections.namedtuple(
    'VersionConfig', 'wheels tpu py_version cuda_version')

# CUDA version whose wheels are installed when a GPU runtime is detected.
DEFAULT_CUDA_VERSION = '10.2'

# Lower bounds used when validating dated (YYYYMMDD) nightly versions.
OLDEST_VERSION = datetime(2020, 3, 18)
OLDEST_GPU_VERSION = datetime(2020, 7, 7)

# Storage bucket prefix the wheels are copied from.
DIST_BUCKET = 'gs://tpu-pytorch/wheels'

# Wheel file-name templates; filled in with `whl_version` and `py_version`.
TORCH_WHEEL_TMPL = (
    'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl')
TORCH_XLA_WHEEL_TMPL = (
    'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl')
TORCHVISION_WHEEL_TMPL = (
    'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl')
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def is_gpu_runtime():
    """Return True when the COLAB_GPU environment variable is set to '1'.

    Returns:
        bool: True if COLAB_GPU equals '1' (ignoring surrounding
        whitespace), False if it is unset or holds any other value.
    """
    # Environment values are always strings, so the flag must be compared
    # against '1'; comparing against the integer 1 can never be true.
    return os.environ.get('COLAB_GPU', '0').strip() == '1'
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def is_tpu_runtime():
    """Return True when the TPU_NAME environment variable is defined.

    Only presence is checked; an empty value still counts as defined.
    """
    tpu_name = os.environ.get('TPU_NAME')
    return tpu_name is not None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def update_tpu_runtime(tpu_name, version):
    """Ask the TPU identified by `tpu_name` to switch to `version.tpu`.

    Args:
        tpu_name: Value handed to `cloud_tpu_client.Client`.
        version: Object exposing a `tpu` attribute naming the runtime
            version to configure (a `VersionConfig` in this script).
    """
    target_runtime = version.tpu
    print(f'Updating TPU runtime to {target_runtime} ...')

    try:
        import cloud_tpu_client
    except ImportError:
        # The client is not installed yet: install it with pip for the
        # running interpreter, then retry the import.
        pip_install = [
            sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client']
        subprocess.call(pip_install)
        import cloud_tpu_client

    cloud_tpu_client.Client(tpu_name).configure_tpu_version(target_runtime)
    print('Done updating TPU runtime')
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_py_version():
    """Return the interpreter's major and minor version digits joined.

    For example, the result is '36' on Python 3.6.x.
    """
    major, minor = platform.python_version_tuple()[:2]
    return f'{major}{minor}'
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_cuda_version():
    """Return the CUDA version whose wheels should be installed.

    Returns:
        DEFAULT_CUDA_VERSION when `is_gpu_runtime()` reports a GPU runtime,
        otherwise None.
    """
    if not is_gpu_runtime():
        return None
    # CUDA is available, so the CUDA wheels are installed.
    return DEFAULT_CUDA_VERSION
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_version(version):
    """Translate a user-supplied version string into a `VersionConfig`.

    Three forms are accepted:
      * 'nightly'            -> the latest nightly build.
      * 'YYYYMMDD'           -> the nightly build of that date.
      * dotted digits ('1.5') -> a release version.

    Args:
        version: Version string as given on the command line.

    Returns:
        A `VersionConfig` carrying the wheel version tag, the TPU runtime
        name, the Python version digits and the CUDA version (or None).

    Raises:
        ValueError: If a dated nightly is older than the oldest available
            build (CUDA or non-CUDA, depending on the runtime), or if the
            string matches none of the accepted forms.
    """
    cuda_version = get_cuda_version()
    py_version = get_py_version()

    if version == 'nightly':
        return VersionConfig(
            'nightly', 'pytorch-nightly', py_version, cuda_version)

    try:
        version_date = datetime.strptime(version, '%Y%m%d')
    except ValueError:
        version_date = None  # Not a dated nightly.

    if version_date is not None:
        if cuda_version and version_date < OLDEST_GPU_VERSION:
            raise ValueError(
                f'Oldest nightly version build with CUDA available is {OLDEST_GPU_VERSION}')
        elif not cuda_version and version_date < OLDEST_VERSION:
            raise ValueError(
                f'Oldest nightly version available is {OLDEST_VERSION}')
        return VersionConfig(
            f'nightly+{version}', f'pytorch-dev{version}',
            py_version, cuda_version)

    # Raw string: '\d' and '\.' are invalid escape sequences in a normal
    # string literal and trigger a warning on recent Python versions.
    version_regex = re.compile(r'^(\d+\.)+\d+$')
    if not version_regex.match(version):
        raise ValueError(f'{version} is an invalid torch_xla version pattern')
    return VersionConfig(
        version, f'pytorch-{version}', py_version, cuda_version)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def install_vm(version, apt_packages, is_root=False):
    """Install the torch, torch_xla and torchvision wheels plus apt packages.

    The commands run in this order: pip-uninstall the existing torch and
    torchvision, copy the three wheels from the bucket into the current
    directory with gsutil, pip-install each wheel, then apt-get install
    `apt_packages`. Command exit codes are not checked.

    Args:
        version: `VersionConfig` describing which wheels to fetch.
        apt_packages: Iterable of apt package names to install.
        is_root: When False, the apt-get command is prefixed with `sudo`.
    """
    bucket = DIST_BUCKET
    if version.cuda_version:
        # CUDA wheels live in a sub-directory named after the CUDA version
        # with the dot removed (e.g. 'cuda/102').
        cuda_dir = 'cuda/{}'.format(version.cuda_version.replace('.', ''))
        bucket = os.path.join(DIST_BUCKET, cuda_dir)

    wheel_templates = (
        TORCH_WHEEL_TMPL, TORCH_XLA_WHEEL_TMPL, TORCHVISION_WHEEL_TMPL)
    wheels = [
        template.format(
            whl_version=version.wheels, py_version=version.py_version)
        for template in wheel_templates
    ]

    pip = [sys.executable, '-m', 'pip']
    commands = [pip + ['uninstall', '-y', 'torch', 'torchvision']]
    commands += [
        ['gsutil', 'cp', os.path.join(bucket, wheel), '.'] for wheel in wheels
    ]
    commands += [pip + ['install', wheel] for wheel in wheels]

    # Colab/Kaggle run as root, but not GCE VMs so we need privilege
    sudo_prefix = [] if is_root else ['sudo']
    commands.append(
        sudo_prefix + ['apt-get', 'install', '-y'] + list(apt_packages))

    for command in commands:
        subprocess.call(command)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def run_setup(args):
    """Run the whole setup: resolve the version, update the TPU, install.

    When a TPU runtime is detected, the TPU runtime update runs on a
    background thread while the wheels and apt packages are installed on
    the calling thread; the function waits for that thread before returning.

    Args:
        args: Parsed command-line arguments providing `version`,
            `apt_packages` and `tpu`.
    """
    version = get_version(args.version)
    print('Updating... This may take around 2 minutes.')

    # Start the TPU update first so it overlaps with the installation.
    if is_tpu_runtime():
        tpu_update = threading.Thread(
            target=update_tpu_runtime, args=(args.tpu, version))
        tpu_update.start()

    # No --tpu argument is treated as already running as root.
    running_as_root = not args.tpu
    install_vm(version, args.apt_packages, is_root=running_as_root)

    if is_tpu_runtime():
        tpu_update.join()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the setup.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--version', type=str, default='20200515',
        help='Versions to install (nightly, release version, or YYYYMMDD).')
    parser.add_argument(
        '--apt-packages', nargs='+', default=['libomp5'],
        help='List of apt packages to install')
    parser.add_argument(
        '--tpu', type=str,
        help='[GCP] Name of the TPU (same zone, project as VM running script)')
    args = parser.parse_args()
    run_setup(args)
|