| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| from zoedepth.utils.misc import count_parameters, parallelize |
| from zoedepth.utils.config import get_config |
| from zoedepth.utils.arg_utils import parse_unknown |
| from zoedepth.trainers.builder import get_trainer |
| from zoedepth.models.builder import build_model |
| from zoedepth.data.data_mono import MixedNYUKITTI |
| import torch.utils.data.distributed |
| import torch.multiprocessing as mp |
| import torch |
| import numpy as np |
| from pprint import pprint |
| import argparse |
| import os |
|
|
# Environment variables read by third-party libraries: the PyOpenGL
# platform selection and the wandb start method. wandb itself is only
# imported later (inside main_worker), after these are set.
os.environ.update({
    "PYOPENGL_PLATFORM": "egl",
    "WANDB_START_METHOD": "thread",
})
|
|
|
|
def fix_random_seed(seed: int):
    """
    Seed every RNG used in training and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU generator
    and the CUDA generators (current device and all devices), then disables
    cuDNN autotuning in favour of deterministic algorithms.

    Args:
        seed (int): value used for every generator
    """
    import random

    import numpy
    import torch

    seeders = (
        random.seed,
        numpy.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Reproducibility over speed: no benchmark-driven kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"):
    """
    Optionally load pretrained weights into ``model``.

    The checkpoint path is resolved from ``config``:

    * ``config.checkpoint`` -- used directly as the weights path.
    * otherwise ``config.ckpt_pattern`` -- files in ``checkpoint_dir``
      matching ``*{pattern}*{ckpt_type}*`` are globbed and the first one in
      sorted order is used.

    If neither attribute exists, ``model`` is returned untouched.

    Args:
        config: configuration object inspected for ``checkpoint`` /
            ``ckpt_pattern`` attributes
        model: model to load the weights into
        checkpoint_dir (str): directory searched when ``ckpt_pattern`` is used
        ckpt_type (str): substring expected after the pattern in the filename

    Returns:
        The model returned by ``load_wts``, or ``model`` itself when no
        checkpoint is configured.

    Raises:
        ValueError: if ``ckpt_pattern`` matches no file in ``checkpoint_dir``.
    """
    import glob
    import os

    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        # glob returns matches in arbitrary (filesystem) order; sort so the
        # selected checkpoint is deterministic when several files match.
        matches = sorted(glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*")))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        checkpoint = matches[0]
    else:
        return model

    # Imported only once a checkpoint has been resolved, so the
    # no-checkpoint and no-match paths do not depend on it.
    from zoedepth.models.model_io import load_wts

    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model
|
|
|
|
def main_worker(gpu, ngpus_per_node, config):
    """
    Build, optionally warm-start, and train a model in the current process.

    Args:
        gpu: device index for this process; stored on ``config.gpu`` and
            passed to the trainer as ``device``. When launched through
            ``mp.spawn`` this is the spawned process index.
        ngpus_per_node: number of GPUs on this node. Not used in this body;
            kept because ``mp.spawn`` passes it positionally.
        config: training configuration, mutated in place (``gpu``,
            ``total_params``).

    ``wandb.finish()`` is called on exit whether training succeeded or raised.
    """
    try:
        fix_random_seed(43)

        config.gpu = gpu

        model = build_model(config)
        model = load_ckpt(config, model)
        model = parallelize(config, model)

        total_params = f"{round(count_parameters(model)/1e6,2)}M"
        config.total_params = total_params
        print(f"Total parameters : {total_params}")

        train_loader = MixedNYUKITTI(config, "train").data
        test_loader = MixedNYUKITTI(config, "online_eval").data

        trainer = get_trainer(config)(
            config, model, train_loader, test_loader, device=config.gpu)

        trainer.train()
    finally:
        # Teardown must not mask an exception raised by training: if wandb
        # cannot be imported here, skip the finish call instead of raising
        # ImportError over the original error.
        try:
            import wandb
        except ImportError:
            wandb = None
        if wandb is not None:
            wandb.finish()
|
|
|
|
if __name__ == '__main__':
    # Must be set once, before any process (e.g. the Manager below) is created.
    mp.set_start_method('forkserver')

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="synunet")
    parser.add_argument("-d", "--dataset", type=str, default='mix')
    parser.add_argument("--trainer", type=str, default=None)

    # Arguments argparse does not recognise are handed to parse_unknown and
    # forwarded to get_config as overrides.
    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    overwrite_kwargs["model"] = args.model
    if args.trainer is not None:
        overwrite_kwargs["trainer"] = args.trainer

    config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)

    # Optional dict proxy shared across processes through a Manager.
    config.shared_dict = mp.Manager().dict() if config.use_shared_dict else None

    config.batch_size = config.bs
    config.mode = 'train'
    if config.root != ".":
        # exist_ok removes the check-then-create race when several launched
        # processes create the same output directory.
        os.makedirs(config.root, exist_ok=True)

    # Node layout from SLURM when available; otherwise a single local node.
    try:
        # NOTE(review): only brackets are stripped, so a compressed nodelist
        # such as "node[01-02]" becomes "node01-02" (one entry) -- confirm the
        # expected SLURM_JOB_NODELIST format.
        node_str = os.environ['SLURM_JOB_NODELIST'].replace(
            '[', '').replace(']', '')
        nodes = node_str.split(',')
        config.world_size = len(nodes)
        config.rank = int(os.environ['SLURM_PROCID'])
    except KeyError:
        config.world_size = 1
        config.rank = 0
        nodes = ["127.0.0.1"]

    if config.distributed:
        print(config.rank)
        # NOTE(review): the port is drawn independently by every process that
        # runs this script; on multiple nodes they may disagree on dist_url.
        port = np.random.randint(15000, 15025)
        config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        print(config.dist_url)
        config.dist_backend = 'nccl'
        config.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    config.num_workers = config.workers
    config.ngpus_per_node = ngpus_per_node
    print("Config:")
    pprint(config)
    if config.distributed:
        # One spawned worker per local GPU; world size counts GPUs on all nodes.
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config))
    else:
        if ngpus_per_node == 1:
            config.gpu = 0
        main_worker(config.gpu, ngpus_per_node, config)
|
|