from packaging import version from typing import List, Dict from pathlib import Path import pkg_resources, subprocess, requests, zipfile, launch, sys, os base = Path(__file__).parent req_ = base / "requirements.txt" orange_ = '\033[38;5;208m' blue_ = '\033[38;5;39m' reset_ = '\033[0m' def _sub(inputs: List[str]) -> bool: try: subprocess.run( inputs, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT ) return True except subprocess.CalledProcessError: return False def _check_req(pkg: str, args: str, cmd: str, pkg_list: List[str]) -> None: try: subprocess.run( [pkg, args], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL ) except FileNotFoundError: pkg_list.append(pkg) _sub(cmd.split()) def _install_req_1() -> None: reqs = [] names = [] with open(req_) as file: for pkg in file: pkg = pkg.strip() if '==' in pkg: pkg_name, pkg_version = pkg.split('==') try: _version = pkg_resources.get_distribution(pkg_name).version if version.parse(_version) < version.parse(pkg_version): reqs.append(pkg) names.append(pkg_name) except pkg_resources.DistributionNotFound: reqs.append(pkg) names.append(pkg_name) else: if not launch.is_installed(pkg): reqs.append(pkg) names.append(pkg) if not sys.platform == 'win32': if not launch.is_installed('aria2'): reqs.append('aria2') names.append('aria2') if reqs: print( f"Installing SD-Hub requirement: " f"{' '.join(f'{orange_}{pkg}{reset_}' for pkg in names)}" ) for pkg in reqs: subprocess.run( [sys.executable, '-m', 'pip', 'install', '-q', pkg] ) def _install_req_2() -> None: pkg_list: List[str] = [] if sys.platform == 'win32': aria2_exe = base / 'aria2c.exe' if not launch.is_installed('lz4'): pkg_list.append('lz4') if not aria2_exe.exists(): pkg_list.append('aria2') for pkg_name in pkg_list: if pkg_name == 'lz4': subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'lz4']) elif pkg_name == 'aria2': aria2_url = 'https://github.com/aria2/aria2/releases/download/release-1.37.0/aria2-1.37.0-win-64bit-build1.zip' with requests.get(aria2_url, stream=True) as r: r.raise_for_status() aria2_zip = base / Path(aria2_url).name with open(aria2_zip, 'wb') as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) with zipfile.ZipFile(aria2_zip, 'r') as zip_ref: for f in zip_ref.infolist(): if f.filename.endswith('aria2c.exe'): f.filename = Path(f.filename).name zip_ref.extract(f, base) break aria2_zip.unlink() else: env_list: Dict[str, str] = { 'Colab': 'COLAB_JUPYTER_TRANSPORT', 'SageMaker Studio Lab': 'SAGEMAKER_INTERNAL_IMAGE_URI', 'Kaggle': 'KAGGLE_DATA_PROXY_TOKEN' } env = 'Unknown' for envs, var in env_list.items(): if var in os.environ: env = envs break pkg_cmds: Dict[str, str] = { 'apt': 'update', 'conda': '--version' } pv_lz4: Dict[str, str] = { 'pv': '-V', 'lz4': '-V' } if env in ['Colab', 'Kaggle']: if _sub(['apt', pkg_cmds['apt']]): for pkg, args in pv_lz4.items(): _check_req(pkg, args, f"apt -y install {pkg}", pkg_list) elif env == 'SageMaker Studio Lab': if _sub(['conda', pkg_cmds['conda']]): for pkg, args in pv_lz4.items(): _check_req(pkg, args, f"conda install -qyc conda-forge {pkg}", pkg_list) elif env == 'Unknown': if _sub(['apt', pkg_cmds['apt']]): for pkg, args in pv_lz4.items(): _check_req(pkg, args, f"apt -y install {pkg}", pkg_list) elif _sub(['conda', pkg_cmds['conda']]): for pkg, args in pv_lz4.items(): _check_req(pkg, args, f"conda install -qyc conda-forge {pkg}", pkg_list) else: print("SD-Hub: Failed to install pv and lz4 in an unknown environment") if pkg_list: print( f"Installing SD-Hub requirement: " f"{' '.join(f'{blue_}{pkg}{reset_}' for pkg in pkg_list)}" ) _install_req_1() _install_req_2()