# Copyright (c) CAIRI AI Lab. All rights reserved

import os
import logging
import random
from collections import OrderedDict
from typing import Tuple

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
from torch import distributed as dist

from .config_utils import Config


def set_seed(seed):
    """Fix python, numpy and torch random seeds and make cudnn deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True


def print_log(message):
    """Print a message and also record it via the logging module."""
    print(message)
    logging.info(message)


def output_namespace(namespace):
    """Format the attributes of a namespace as a printable string."""
    configs = namespace.__dict__
    message = ''
    for k, v in configs.items():
        message += '\n' + k + ': \t' + str(v) + '\t'
    return message


def check_dir(path):
    """Create `path` if it does not exist; return whether it already existed."""
    if not os.path.exists(path):
        os.makedirs(path)
        return False
    return True


def get_dataset(config):
    from src.datasets import load_data
    return load_data(**config)


def count_parameters(model):
    """Count the trainable parameters of a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def measure_throughput(model, input_dummy):
    """Measure model throughput (samples per second) on random inputs."""
    bs = 100
    repetitions = 100
    if isinstance(input_dummy, tuple):
        input_dummy = list(input_dummy)
        _, T, C, H, W = input_dummy[0].shape
        _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device)
        input_dummy[0] = _input
        input_dummy = tuple(input_dummy)
    else:
        _, T, C, H, W = input_dummy.shape
        input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device)
    total_time = 0
    with torch.no_grad():
        for _ in range(repetitions):
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            starter.record()
            if isinstance(input_dummy, tuple):
                _ = model(*input_dummy)
            else:
                _ = model(input_dummy)
            ender.record()
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender) / 1000  # ms -> s
            total_time += curr_time
    throughput = (repetitions * bs) / total_time
    return throughput


def load_config(filename: str = None):
    """Load and return a config dict from `filename` (empty dict on failure)."""
    print('loading config from ' + filename + ' ...')
    try:
        configfile = Config(filename=filename)
        config = configfile._cfg_dict
    except (FileNotFoundError, IOError):
        config = dict()
        print('warning: failed to load the config!')
    return config


def update_config(args, config, exclude_keys=list()):
    """Update the args dict with a new config dict.

    Keys already set in `args` keep their value (a mismatch is reported),
    unless they are listed in `exclude_keys`, in which case the config
    value wins.
    """
    assert isinstance(args, dict) and isinstance(config, dict)
    for k in config.keys():
        if args.get(k, False):
            if args[k] != config[k] and k not in exclude_keys:
                print(f'overwrite config key -- {k}: {config[k]} -> {args[k]}')
            else:
                args[k] = config[k]
        else:
            args[k] = config[k]
    return args


def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
    """Copy a model state_dict to CPU.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(  # type: ignore
        state_dict, '_metadata', OrderedDict())
    return state_dict_cpu


def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    """Initialize torch.distributed for the given launcher ('pytorch' or 'mpi')."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend: str, **kwargs) -> None:
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend: str, **kwargs) -> None:
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)


def get_dist_info() -> Tuple[int, int]:
    """Return (rank, world_size); (0, 1) when distributed is not initialized."""
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size