Spaces: Running on Zero
| # Copyright (c) CAIRI AI Lab. All rights reserved | |
| import os | |
| import logging | |
| import numpy as np | |
| import torch | |
| import random | |
| import torch.backends.cudnn as cudnn | |
| from collections import OrderedDict | |
| from typing import Tuple | |
| from .config_utils import Config | |
| import torch | |
| import torch.multiprocessing as mp | |
| from torch import distributed as dist | |
def set_seed(seed):
    """Seed every RNG source (python, numpy, torch CPU and CUDA) for reproducibility.

    Args:
        seed (int): the seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every CUDA device as well; safe no-op without CUDA and
    # guarantees coverage on torch versions where manual_seed is CPU-only.
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels (trades speed for reproducibility).
    cudnn.deterministic = True
def print_log(message):
    """Echo *message* to stdout and record it at INFO level in the log."""
    logging.info(message)
    print(message)
def output_namespace(namespace):
    """Render each attribute of *namespace* as a '\\nkey: \\tvalue\\t' segment."""
    segments = ('\n' + name + ': \t' + str(value) + '\t'
                for name, value in namespace.__dict__.items())
    return ''.join(segments)
def check_dir(path):
    """Ensure *path* exists as a directory, creating it if needed.

    Args:
        path (str): directory path to check/create.

    Returns:
        bool: True if the path already existed, False if it was just created.
    """
    # EAFP: create directly instead of exists()+makedirs(), which races when
    # another process creates the directory between the check and the call.
    try:
        os.makedirs(path)
    except FileExistsError:
        return True
    return False
def get_dataset(config):
    """Build and return the dataset described by *config*.

    The whole config dict is forwarded as keyword arguments to
    ``src.datasets.load_data``; the import is local, presumably to avoid a
    circular import at module load time (TODO confirm).
    """
    from src.datasets import load_data
    return load_data(**config)
def count_parameters(model):
    """Return the total number of trainable (requires_grad) scalars in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def measure_throughput(model, input_dummy):
    """Benchmark model inference throughput in samples per second.

    Replaces the batch dimension of ``input_dummy`` with a fixed benchmark
    batch of 100 random samples and times 100 forward passes with CUDA
    events. Requires a CUDA device (uses torch.cuda.Event / synchronize).

    Args:
        model: callable invoked as ``model(x)`` or ``model(*xs)`` when
            ``input_dummy`` is a tuple.
        input_dummy: a 5-D tensor of shape (B, T, C, H, W), or a tuple whose
            first element is such a tensor (remaining elements are reused
            as-is for every forward pass).

    Returns:
        float: samples processed per second, averaged over all repetitions.
    """
    bs = 100
    repetitions = 100
    if isinstance(input_dummy, tuple):
        # Rebuild the tuple with its first tensor resized to the benchmark batch.
        input_dummy = list(input_dummy)
        _, T, C, H, W = input_dummy[0].shape
        _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device)
        input_dummy[0] = _input
        input_dummy = tuple(input_dummy)
    else:
        _, T, C, H, W = input_dummy.shape
        input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device)
    total_time = 0
    with torch.no_grad():
        for _ in range(repetitions):
            # CUDA events measure device-side time: record around the forward
            # pass, then synchronize before reading the elapsed time.
            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()
            if isinstance(input_dummy, tuple):
                _ = model(*input_dummy)
            else:
                _ = model(input_dummy)
            ender.record()
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender) / 1000  # ms -> s
            total_time += curr_time
    Throughput = (repetitions * bs) / total_time
    return Throughput
def load_config(filename: str = None):
    """Load a config dict from *filename*; return an empty dict on failure.

    Args:
        filename (str | None): path to the config file.

    Returns:
        dict: the parsed config, or {} when no file is given or loading fails.
    """
    # Guard the declared default: the old "'...' + filename" concatenation
    # raised TypeError before any error handling when filename was None.
    if filename is None:
        print('warning: no config file specified!')
        return dict()
    print(f'loading config from {filename} ...')
    try:
        configfile = Config(filename=filename)
        config = configfile._cfg_dict
    except (FileNotFoundError, IOError):
        config = dict()
        print('warning: fail to load the config!')
    return config
def update_config(args, config, exclude_keys=()):
    """Merge *config* into *args* in place, letting existing args values win.

    For each key in config: if args holds a non-None value that differs and
    the key is not excluded, the args value is kept and the override is
    reported; otherwise the config value is copied into args (excluded keys
    therefore always take the config value).

    Args:
        args (dict): destination dict, modified in place and returned.
        config (dict): source of default values.
        exclude_keys: keys for which config silently overrides args.

    Returns:
        dict: the updated *args*.
    """
    assert isinstance(args, dict) and isinstance(config, dict)
    for k in config.keys():
        # 'is not None' rather than truthiness, so explicit falsy values in
        # args (0, False, '') are respected instead of being clobbered.
        if args.get(k) is not None:
            if args[k] != config[k] and k not in exclude_keys:
                print(f'overwrite config key -- {k}: {config[k]} -> {args[k]}')
            else:
                args[k] = config[k]
        else:
            args[k] = config[k]
    return args
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
    """Copy a model state_dict to CPU.

    Args:
        state_dict (OrderedDict): model weights, possibly on GPU.

    Returns:
        OrderedDict: the same weights moved to CPU.
    """
    cpu_state = OrderedDict(
        (name, tensor.cpu()) for name, tensor in state_dict.items())
    # Preserve the _metadata attribute torch attaches to state dicts.
    cpu_state._metadata = getattr(  # type: ignore
        state_dict, '_metadata', OrderedDict())
    return cpu_state
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    """Initialise torch.distributed for the given launcher ('pytorch' or 'mpi')."""
    # Default to 'spawn' workers unless a start method was already chosen.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend: str, **kwargs) -> None:
    """Initialise a process group for the torch.distributed launcher.

    Reads the global rank from the RANK environment variable (set by the
    launcher) and binds this process to GPU ``rank % num_gpus``.

    Args:
        backend (str): distributed backend, e.g. 'nccl'.
        **kwargs: forwarded to ``dist.init_process_group``.
    """
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend: str, **kwargs) -> None:
    """Initialise a process group when launched under OpenMPI.

    Translates the OMPI_COMM_WORLD_* environment variables into the
    MASTER_PORT/WORLD_SIZE/RANK variables torch.distributed expects;
    MASTER_ADDR must be provided by the caller's environment.

    Args:
        backend (str): distributed backend, e.g. 'nccl'.
        **kwargs: forwarded to ``dist.init_process_group``.

    Raises:
        KeyError: if MASTER_ADDR is not set in the environment.
    """
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)
def get_dist_info() -> Tuple[int, int]:
    """Return (rank, world_size); (0, 1) when distributed is unavailable/uninitialised."""
    if not (dist.is_available() and dist.is_initialized()):
        return 0, 1
    return dist.get_rank(), dist.get_world_size()