| | import urllib.request |
| | import tempfile |
| | import os |
| | import uuid |
| | import shutil |
| | import glob |
| |
|
| | import yaml |
| | import hashlib |
| |
|
| | from zipfile import ZipFile |
| | from sys import platform |
| | from typing import Tuple, Optional, Dict, Any |
| |
|
| | from filelock import FileLock |
| |
|
| | from mlagents_envs.env_utils import validate_environment_path |
| |
|
| | from mlagents_envs.logging_util import get_logger |
| |
|
| | logger = get_logger(__name__) |
| |
|
| | |
| | BLOCK_SIZE = 8192 |
| |
|
| |
|
| | def get_local_binary_path(name: str, url: str, tmp_dir: Optional[str] = None) -> str: |
| | """ |
| | Returns the path to the executable previously downloaded with the name argument. If |
| | None is found, the executable at the url argument will be downloaded and stored |
| | under name for future uses. |
| | :param name: The name that will be given to the folder containing the extracted data |
| | :param url: The URL of the zip file |
| | :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in. |
| | """ |
| | NUMBER_ATTEMPTS = 5 |
| | tmp_dir = tmp_dir or tempfile.gettempdir() |
| | lock = FileLock(os.path.join(tmp_dir, name + ".lock")) |
| | with lock: |
| | path = get_local_binary_path_if_exists(name, url, tmp_dir=tmp_dir) |
| | if path is None: |
| | logger.debug( |
| | f"Local environment {name} not found, downloading environment from {url}" |
| | ) |
| | for attempt in range( |
| | NUMBER_ATTEMPTS |
| | ): |
| | if path is not None: |
| | break |
| | try: |
| | download_and_extract_zip(url, name, tmp_dir=tmp_dir) |
| | except Exception: |
| | if attempt + 1 < NUMBER_ATTEMPTS: |
| | logger.warning( |
| | f"Attempt {attempt + 1} / {NUMBER_ATTEMPTS}" |
| | ": Failed to download and extract binary." |
| | ) |
| | else: |
| | raise |
| | path = get_local_binary_path_if_exists(name, url, tmp_dir=tmp_dir) |
| |
|
| | if path is None: |
| | raise FileNotFoundError( |
| | f"Binary not found, make sure {url} is a valid url to " |
| | "a zip folder containing a valid Unity executable" |
| | ) |
| | return path |
| |
|
| |
|
| | def get_local_binary_path_if_exists(name: str, url: str, tmp_dir: str) -> Optional[str]: |
| | """ |
| | Recursively searches for a Unity executable in the extracted files folders. This is |
| | platform dependent : It will only return a Unity executable compatible with the |
| | computer's OS. If no executable is found, None will be returned. |
| | :param name: The name/identifier of the executable |
| | :param url: The url the executable was downloaded from (for verification) |
| | :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in. |
| | """ |
| | _, bin_dir = get_tmp_dirs(tmp_dir) |
| | extension = None |
| |
|
| | if platform == "linux" or platform == "linux2": |
| | extension = "*.x86_64" |
| | if platform == "darwin": |
| | extension = "*.app" |
| | if platform == "win32": |
| | extension = "*.exe" |
| | if extension is None: |
| | raise NotImplementedError("No extensions found for this platform.") |
| | url_hash = "-" + hashlib.md5(url.encode()).hexdigest() |
| | path = os.path.join(bin_dir, name + url_hash, "**", extension) |
| | candidates = glob.glob(path, recursive=True) |
| | if len(candidates) == 0: |
| | return None |
| | else: |
| | for c in candidates: |
| | |
| | if "UnityCrashHandler64" not in c: |
| | |
| | if validate_environment_path(c) is None: |
| | shutil.rmtree(c) |
| | return None |
| | return c |
| | return None |
| |
|
| |
|
| | def _get_tmp_dir_helper(tmp_dir: Optional[str] = None) -> Tuple[str, str]: |
| | tmp_dir = tmp_dir or ("/tmp" if platform == "darwin" else tempfile.gettempdir()) |
| | MLAGENTS = "ml-agents-binaries" |
| | TMP_FOLDER_NAME = "tmp" |
| | BINARY_FOLDER_NAME = "binaries" |
| | mla_directory = os.path.join(tmp_dir, MLAGENTS) |
| | if not os.path.exists(mla_directory): |
| | os.makedirs(mla_directory) |
| | os.chmod(mla_directory, 16877) |
| | zip_directory = os.path.join(tmp_dir, MLAGENTS, TMP_FOLDER_NAME) |
| | if not os.path.exists(zip_directory): |
| | os.makedirs(zip_directory) |
| | os.chmod(zip_directory, 16877) |
| | bin_directory = os.path.join(tmp_dir, MLAGENTS, BINARY_FOLDER_NAME) |
| | if not os.path.exists(bin_directory): |
| | os.makedirs(bin_directory) |
| | os.chmod(bin_directory, 16877) |
| | return zip_directory, bin_directory |
| |
|
| |
|
| | def get_tmp_dirs(tmp_dir: Optional[str] = None) -> Tuple[str, str]: |
| | """ |
| | Returns the path to the folder containing the downloaded zip files and the extracted |
| | binaries. If these folders do not exist, they will be created. |
| | :retrun: Tuple containing path to : (zip folder, extracted files folder) |
| | """ |
| | |
| | |
| | for _attempt in range(3): |
| | try: |
| | return _get_tmp_dir_helper(tmp_dir) |
| | except FileExistsError: |
| | continue |
| | return _get_tmp_dir_helper(tmp_dir) |
| |
|
| |
|
| | def download_and_extract_zip( |
| | url: str, name: str, tmp_dir: Optional[str] = None |
| | ) -> None: |
| | """ |
| | Downloads a zip file under a URL, extracts its contents into a folder with the name |
| | argument and gives chmod 755 to all the files it contains. Files are downloaded and |
| | extracted into special folders in the temp folder of the machine. |
| | :param url: The URL of the zip file |
| | :param name: The name that will be given to the folder containing the extracted data |
| | :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in. |
| | """ |
| | zip_dir, bin_dir = get_tmp_dirs(tmp_dir) |
| | url_hash = "-" + hashlib.md5(url.encode()).hexdigest() |
| | binary_path = os.path.join(bin_dir, name + url_hash) |
| | if os.path.exists(binary_path): |
| | shutil.rmtree(binary_path) |
| |
|
| | |
| | try: |
| | request = urllib.request.urlopen(url, timeout=30) |
| | except urllib.error.HTTPError as e: |
| | e.reason = f"{e.reason} {url}" |
| | raise |
| | zip_size = int(request.headers["content-length"]) |
| | zip_file_path = os.path.join(zip_dir, str(uuid.uuid4()) + ".zip") |
| | with open(zip_file_path, "wb") as zip_file: |
| | downloaded = 0 |
| | while True: |
| | buffer = request.read(BLOCK_SIZE) |
| | if not buffer: |
| | |
| | break |
| | downloaded += len(buffer) |
| | zip_file.write(buffer) |
| | downloaded_percent = downloaded / zip_size * 100 |
| | print_progress(f" Downloading {name}", downloaded_percent) |
| | print("") |
| |
|
| | |
| | with ZipFileWithProgress(zip_file_path, "r") as zip_ref: |
| | zip_ref.extract_zip(f" Extracting {name}", binary_path) |
| | print("") |
| |
|
| | |
| | print_progress(f" Cleaning up {name}", 0) |
| | os.remove(zip_file_path) |
| |
|
| | |
| | for f in glob.glob(binary_path + "/**/*", recursive=True): |
| | |
| | os.chmod(f, 16877) |
| | print_progress(f" Cleaning up {name}", 100) |
| | print("") |
| |
|
| |
|
| | def print_progress(prefix: str, percent: float) -> None: |
| | """ |
| | Displays a single progress bar in the terminal with value percent. |
| | :param prefix: The string that will precede the progress bar. |
| | :param percent: The percent progression of the bar (min is 0, max is 100) |
| | """ |
| | BAR_LEN = 20 |
| | percent = min(100, max(0, percent)) |
| | bar_progress = min(int(percent / 100 * BAR_LEN), BAR_LEN) |
| | bar = "|" + "\u2588" * bar_progress + " " * (BAR_LEN - bar_progress) + "|" |
| | str_percent = "%3.0f%%" % percent |
| | print(f"{prefix} : {bar} {str_percent} \r", end="", flush=True) |
| |
|
| |
|
| | def load_remote_manifest(url: str) -> Dict[str, Any]: |
| | """ |
| | Converts a remote yaml file into a Python dictionary |
| | """ |
| | tmp_dir, _ = get_tmp_dirs() |
| | try: |
| | request = urllib.request.urlopen(url, timeout=30) |
| | except urllib.error.HTTPError as e: |
| | e.reason = f"{e.reason} {url}" |
| | raise |
| | manifest_path = os.path.join(tmp_dir, str(uuid.uuid4()) + ".yaml") |
| | with open(manifest_path, "wb") as manifest: |
| | while True: |
| | buffer = request.read(BLOCK_SIZE) |
| | if not buffer: |
| | |
| | break |
| | manifest.write(buffer) |
| | try: |
| | result = load_local_manifest(manifest_path) |
| | finally: |
| | os.remove(manifest_path) |
| | return result |
| |
|
| |
|
| | def load_local_manifest(path: str) -> Dict[str, Any]: |
| | """ |
| | Converts a local yaml file into a Python dictionary |
| | """ |
| | with open(path) as data_file: |
| | return yaml.safe_load(data_file) |
| |
|
| |
|
| | class ZipFileWithProgress(ZipFile): |
| | """ |
| | This is a helper class inheriting from ZipFile that allows to display a progress |
| | bar while the files are being extracted. |
| | """ |
| |
|
| | def extract_zip(self, prefix: str, path: str) -> None: |
| | members = self.namelist() |
| | path = os.fspath(path) |
| | total = len(members) |
| | n = 0 |
| | for zipinfo in members: |
| | self.extract(zipinfo, path, None) |
| | n += 1 |
| | print_progress(prefix, n / total * 100) |
| |
|