import os
import pprint
import shutil
import sys
from datetime import datetime
from pathlib import Path

import torch
import yaml
from easydict import EasyDict as edict

from .distributed import get_world_size, synchronize
from .log import add_logging, logger
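
# Sets up a training run: resolves the experiment directory from the model
# script's location under "models/", loads and merges configs, initializes
# distributed training if requested, and snapshots the model script for
# reproducibility. A minimal (hypothetical) call, assuming an argparse
# namespace that provides the fields accessed below:
#
#     args = parser.parse_args()  # model_path, distributed, local_rank,
#                                 # workers, temp_model_path, ...
#     cfg = init_experiment(args, model_name="baseline")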
def init_experiment(args, model_name):
    model_path = Path(args.model_path)
    ftree = get_model_family_tree(model_path, model_name=model_name)
    if ftree is None:
        print(
            'Models can only be located in the "models" directory in the root of the repository'
        )
        sys.exit(1)

    cfg = load_config(model_path)
    update_config(cfg, args)

    cfg.distributed = args.distributed
    cfg.local_rank = args.local_rank
    if cfg.distributed:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        if args.workers > 0:
            torch.multiprocessing.set_start_method("forkserver", force=True)

    experiments_path = Path(cfg.EXPS_PATH)
    exp_parent_path = experiments_path / "/".join(ftree)
    exp_parent_path.mkdir(parents=True, exist_ok=True)

    if cfg.resume_exp:
        exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp)
    else:
        last_exp_indx = find_last_exp_indx(exp_parent_path)
        exp_name = f"{last_exp_indx:03d}"
        if cfg.exp_name:
            exp_name += "_" + cfg.exp_name
        exp_path = exp_parent_path / exp_name
        synchronize()
        if cfg.local_rank == 0:
            exp_path.mkdir(parents=True)

    cfg.EXP_PATH = exp_path
    cfg.CHECKPOINTS_PATH = exp_path / "checkpoints"
    cfg.VIS_PATH = exp_path / "vis"
    cfg.LOGS_PATH = exp_path / "logs"

    if cfg.local_rank == 0:
        cfg.LOGS_PATH.mkdir(exist_ok=True)
        cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True)
        cfg.VIS_PATH.mkdir(exist_ok=True)

        dst_script_path = exp_path / (
            model_path.stem
            + datetime.strftime(datetime.today(), "_%Y-%m-%d-%H-%M-%S.py")
        )
        if args.temp_model_path:
            shutil.copy(args.temp_model_path, dst_script_path)
            os.remove(args.temp_model_path)
        else:
            shutil.copy(model_path, dst_script_path)

    synchronize()

    if cfg.gpus != "":
        gpu_ids = [int(id) for id in cfg.gpus.split(",")]
    else:
        gpu_ids = list(range(max(cfg.ngpus, get_world_size())))
        cfg.gpus = ",".join([str(id) for id in gpu_ids])
    cfg.gpu_ids = gpu_ids
    cfg.ngpus = len(gpu_ids)
    cfg.multi_gpu = cfg.ngpus > 1

    if cfg.distributed:
        cfg.device = torch.device("cuda")
        cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]]
        torch.cuda.set_device(cfg.gpu_ids[0])
    else:
        if cfg.multi_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpus
            ngpus = torch.cuda.device_count()
            assert ngpus == cfg.ngpus
        cfg.device = torch.device(f"cuda:{cfg.gpu_ids[0]}")

    if cfg.local_rank == 0:
        add_logging(cfg.LOGS_PATH, prefix="train_")
        logger.info(f"Number of GPUs: {cfg.ngpus}")
        if cfg.distributed:
            logger.info("Multi-Process Multi-GPU Distributed Training")

        logger.info("Run experiment with config:")
        logger.info(pprint.pformat(cfg, indent=4))

    return cfg
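
# Walks up from the model script to the nearest "models" directory and returns
# the intermediate path components, oldest first; returns None if the script
# does not live under a "models" directory. For example (hypothetical layout):
#
#     get_model_family_tree(Path("models/iter_mask/baseline.py"))
#     -> ["iter_mask", "baseline"]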
def get_model_family_tree(model_path, terminate_name="models", model_name=None):
    if model_name is None:
        model_name = model_path.stem

    family_tree = [model_name]
    for x in model_path.parents:
        if x.stem == terminate_name:
            break
        family_tree.append(x.stem)
    else:
        return None

    return family_tree[::-1]
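
# Experiment directories are named with a three-digit numeric prefix
# (e.g. "000_baseline", "001_aug"); this returns one past the highest
# prefix found, i.e. the index for the next experiment.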
def find_last_exp_indx(exp_parent_path):
    indx = 0
    for x in exp_parent_path.iterdir():
        if not x.is_dir():
            continue

        exp_name = x.stem
        if exp_name[:3].isnumeric():
            indx = max(indx, int(exp_name[:3]) + 1)

    return indx
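
# Resolves the resume_exp pattern to exactly one existing experiment
# directory; exits if the glob "{exp_pattern}*" matches zero or several
# directories (e.g. pattern "001" matches "001_baseline").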
def find_resume_exp(exp_parent_path, exp_pattern):
    candidates = sorted(exp_parent_path.glob(f"{exp_pattern}*"))
    if len(candidates) == 0:
        print(
            f'No experiments could be found that satisfy the pattern = "{exp_pattern}*"'
        )
        sys.exit(1)
    elif len(candidates) > 1:
        print("More than one experiment found:")
        for x in candidates:
            print(x)
        sys.exit(1)
    else:
        exp_path = candidates[0]
        print(f'Continue with experiment "{exp_path}"')

    return exp_path
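
# Copies command-line arguments into the config without overriding keys that
# the YAML configs already define (checked via the lower- and upper-case
# forms of the argument name).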
def update_config(cfg, args):
    for param_name, value in vars(args).items():
        if param_name.lower() in cfg or param_name.upper() in cfg:
            continue
        cfg[param_name] = value
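
# Loads "<model_name>.yml" next to the model script if present, then walks up
# from the script's directory to the current working directory, merging any
# "config.yml" found along the way; keys set closer to the model script win.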
def load_config(model_path):
    model_name = model_path.stem
    config_path = model_path.parent / (model_name + ".yml")

    if config_path.exists():
        cfg = load_config_file(config_path)
    else:
        cfg = dict()

    cwd = Path.cwd()
    config_parent = config_path.parent.absolute()
    while len(config_parent.parents) > 0:
        config_path = config_parent / "config.yml"

        if config_path.exists():
            local_config = load_config_file(config_path, model_name=model_name)
            cfg.update({k: v for k, v in local_config.items() if k not in cfg})

        if config_parent.absolute() == cwd:
            break
        config_parent = config_parent.parent

    return edict(cfg)
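
# Reads a single YAML config. A file may carry a "SUBCONFIGS" section with
# per-model overrides; the entry matching model_name (if any) is merged into
# the top level and the section is dropped. A minimal config.yml sketch
# (hypothetical values):
#
#     EXPS_PATH: ./experiments
#     SUBCONFIGS:
#         baseline:
#             ngpus: 2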
def load_config_file(config_path, model_name=None, return_edict=False):
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

    if "SUBCONFIGS" in cfg:
        if model_name is not None and model_name in cfg["SUBCONFIGS"]:
            cfg.update(cfg["SUBCONFIGS"][model_name])
        del cfg["SUBCONFIGS"]

    return edict(cfg) if return_edict else cfg
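

if __name__ == "__main__":
    # Smoke test for the path helper only (a sketch; init_experiment needs a
    # real config tree and CLI arguments). The paths below are hypothetical.
    print(get_model_family_tree(Path("models/iter_mask/baseline.py")))
    # -> ['iter_mask', 'baseline']
    print(get_model_family_tree(Path("notebooks/baseline.py")))
    # -> None (the script is not under a "models" directory)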