Spaces:
Runtime error
Runtime error
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This work is licensed under the Creative Commons Attribution-NonCommercial | |
| # 4.0 International License. To view a copy of this license, visit | |
| # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
| # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
| """Submit a function to be run either locally or in a computing cluster.""" | |
| import copy | |
| import io | |
| import os | |
| import pathlib | |
| import pickle | |
| import platform | |
| import pprint | |
| import re | |
| import shutil | |
| import time | |
| import traceback | |
| import zipfile | |
| from enum import Enum | |
| from .. import util | |
| from ..util import EasyDict | |
| class SubmitTarget(Enum): | |
| """The target where the function should be run. | |
| LOCAL: Run it locally. | |
| """ | |
| LOCAL = 1 | |
| class PathType(Enum): | |
| """Determines in which format should a path be formatted. | |
| WINDOWS: Format with Windows style. | |
| LINUX: Format with Linux/Posix style. | |
| AUTO: Use current OS type to select either WINDOWS or LINUX. | |
| """ | |
| WINDOWS = 1 | |
| LINUX = 2 | |
| AUTO = 3 | |
| _user_name_override = None | |
| class SubmitConfig(util.EasyDict): | |
| """Strongly typed config dict needed to submit runs. | |
| Attributes: | |
| run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. | |
| run_desc: Description of the run. Will be used in the run dir and task name. | |
| run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. | |
| run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. | |
| submit_target: Submit target enum value. Used to select where the run is actually launched. | |
| num_gpus: Number of GPUs used/requested for the run. | |
| print_info: Whether to print debug information when submitting. | |
| ask_confirmation: Whether to ask a confirmation before submitting. | |
| run_id: Automatically populated value during submit. | |
| run_name: Automatically populated value during submit. | |
| run_dir: Automatically populated value during submit. | |
| run_func_name: Automatically populated value during submit. | |
| run_func_kwargs: Automatically populated value during submit. | |
| user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. | |
| task_name: Automatically populated value during submit. | |
| host_name: Automatically populated value during submit. | |
| """ | |
| def __init__(self): | |
| super().__init__() | |
| # run (set these) | |
| self.run_dir_root = "" # should always be passed through get_path_from_template | |
| self.run_desc = "" | |
| self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] | |
| self.run_dir_extra_files = None | |
| # submit (set these) | |
| self.submit_target = SubmitTarget.LOCAL | |
| self.num_gpus = 1 | |
| self.print_info = False | |
| self.ask_confirmation = False | |
| # (automatically populated) | |
| self.run_id = None | |
| self.run_name = None | |
| self.run_dir = None | |
| self.run_func_name = None | |
| self.run_func_kwargs = None | |
| self.user_name = None | |
| self.task_name = None | |
| self.host_name = "localhost" | |
| def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: | |
| """Replace tags in the given path template and return either Windows or Linux formatted path.""" | |
| # automatically select path type depending on running OS | |
| if path_type == PathType.AUTO: | |
| if platform.system() == "Windows": | |
| path_type = PathType.WINDOWS | |
| elif platform.system() == "Linux": | |
| path_type = PathType.LINUX | |
| else: | |
| raise RuntimeError("Unknown platform") | |
| path_template = path_template.replace("<USERNAME>", get_user_name()) | |
| # return correctly formatted path | |
| if path_type == PathType.WINDOWS: | |
| return str(pathlib.PureWindowsPath(path_template)) | |
| elif path_type == PathType.LINUX: | |
| return str(pathlib.PurePosixPath(path_template)) | |
| else: | |
| raise RuntimeError("Unknown platform") | |
| def get_template_from_path(path: str) -> str: | |
| """Convert a normal path back to its template representation.""" | |
| # replace all path parts with the template tags | |
| path = path.replace("\\", "/") | |
| return path | |
| def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: | |
| """Convert a normal path to template and the convert it back to a normal path with given path type.""" | |
| path_template = get_template_from_path(path) | |
| path = get_path_from_template(path_template, path_type) | |
| return path | |
| def set_user_name_override(name: str) -> None: | |
| """Set the global username override value.""" | |
| global _user_name_override | |
| _user_name_override = name | |
| def get_user_name(): | |
| """Get the current user name.""" | |
| if _user_name_override is not None: | |
| return _user_name_override | |
| elif platform.system() == "Windows": | |
| return os.getlogin() | |
| elif platform.system() == "Linux": | |
| try: | |
| import pwd # pylint: disable=import-error | |
| return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member | |
| except: | |
| return "unknown" | |
| else: | |
| raise RuntimeError("Unknown platform") | |
| def _create_run_dir_local(submit_config: SubmitConfig) -> str: | |
| """Create a new run dir with increasing ID number at the start.""" | |
| run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) | |
| if not os.path.exists(run_dir_root): | |
| print("Creating the run dir root: {}".format(run_dir_root)) | |
| os.makedirs(run_dir_root) | |
| submit_config.run_id = _get_next_run_id_local(run_dir_root) | |
| submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) | |
| run_dir = os.path.join(run_dir_root, submit_config.run_name) | |
| if os.path.exists(run_dir): | |
| raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) | |
| print("Creating the run dir: {}".format(run_dir)) | |
| os.makedirs(run_dir) | |
| return run_dir | |
| def _get_next_run_id_local(run_dir_root: str) -> int: | |
| """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" | |
| dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] | |
| r = re.compile("^\\d+") # match one or more digits at the start of the string | |
| run_id = 0 | |
| for dir_name in dir_names: | |
| m = r.match(dir_name) | |
| if m is not None: | |
| i = int(m.group()) | |
| run_id = max(run_id, i + 1) | |
| return run_id | |
| def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: | |
| """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" | |
| print("Copying files to the run dir") | |
| files = [] | |
| run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) | |
| assert '.' in submit_config.run_func_name | |
| for _idx in range(submit_config.run_func_name.count('.') - 1): | |
| run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) | |
| files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) | |
| dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") | |
| files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) | |
| if submit_config.run_dir_extra_files is not None: | |
| files += submit_config.run_dir_extra_files | |
| files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] | |
| files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] | |
| util.copy_files_and_create_dirs(files) | |
| pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) | |
| with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: | |
| pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) | |
| def run_wrapper(submit_config: SubmitConfig) -> None: | |
| """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" | |
| is_local = submit_config.submit_target == SubmitTarget.LOCAL | |
| checker = None | |
| # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing | |
| if is_local: | |
| logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) | |
| else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) | |
| logger = util.Logger(file_name=None, should_flush=True) | |
| import dnnlib | |
| dnnlib.submit_config = submit_config | |
| try: | |
| print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) | |
| start_time = time.time() | |
| util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) | |
| print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) | |
| except: | |
| if is_local: | |
| raise | |
| else: | |
| traceback.print_exc() | |
| log_src = os.path.join(submit_config.run_dir, "log.txt") | |
| log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) | |
| shutil.copyfile(log_src, log_dst) | |
| finally: | |
| open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() | |
| dnnlib.submit_config = None | |
| logger.close() | |
| if checker is not None: | |
| checker.stop() | |
| def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: | |
| """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.""" | |
| submit_config = copy.copy(submit_config) | |
| if submit_config.user_name is None: | |
| submit_config.user_name = get_user_name() | |
| submit_config.run_func_name = run_func_name | |
| submit_config.run_func_kwargs = run_func_kwargs | |
| assert submit_config.submit_target == SubmitTarget.LOCAL | |
| if submit_config.submit_target in {SubmitTarget.LOCAL}: | |
| run_dir = _create_run_dir_local(submit_config) | |
| submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) | |
| submit_config.run_dir = run_dir | |
| _populate_run_dir(run_dir, submit_config) | |
| if submit_config.print_info: | |
| print("\nSubmit config:\n") | |
| pprint.pprint(submit_config, indent=4, width=200, compact=False) | |
| print() | |
| if submit_config.ask_confirmation: | |
| if not util.ask_yes_no("Continue submitting the job?"): | |
| return | |
| run_wrapper(submit_config) | |