Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,963 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright (c) CAIRI AI Lab. All rights reserved
import os
import logging
import numpy as np
import torch
import random
import torch.backends.cudnn as cudnn
from collections import OrderedDict
from typing import Tuple
from .config_utils import Config
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    The cuDNN ``deterministic`` flag is also switched on.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Request deterministic cuDNN algorithms.
    cudnn.deterministic = True
def print_log(message):
    """Write ``message`` to stdout and log it at INFO level on the root logger.

    Args:
        message: Object to print and to pass to ``logging.info``.
    """
    for emit in (print, logging.info):
        emit(message)
def output_namespace(namespace):
    """Render the attributes of ``namespace`` as a multi-line string.

    Args:
        namespace: Object whose ``__dict__`` holds the settings to display;
            attribute names must be strings.

    Returns:
        str: One ``'\\n<name>: \\t<value>\\t'`` fragment per attribute, in
        ``__dict__`` order; an empty string when there are no attributes.
    """
    attributes = namespace.__dict__
    return ''.join(
        '\n' + name + ': \t' + str(value) + '\t'
        for name, value in attributes.items()
    )
def check_dir(path):
    """Ensure the directory ``path`` exists, creating it when it is missing.

    Args:
        path: Directory path to check. Missing intermediate directories are
            created as well.

    Returns:
        bool: ``True`` if ``path`` already existed when checked, ``False`` if
        this call found it missing and went on to create it.
    """
    if os.path.exists(path):
        return True
    # exist_ok=True closes the check-then-create race: if another process
    # (e.g. a second distributed rank) creates the directory between the
    # existence check above and this call, we must not crash with
    # FileExistsError.
    os.makedirs(path, exist_ok=True)
    return False
def get_dataset(config):
    """Forward ``config`` to ``src.datasets.load_data`` and return its result.

    Args:
        config: Mapping unpacked into keyword arguments for ``load_data``.

    Returns:
        Whatever ``load_data(**config)`` returns.
    """
    # Imported at call time rather than at module import.
    from src.datasets import load_data as _load_data

    return _load_data(**config)
def count_parameters(model):
    """Count the elements of all parameters of ``model`` that require grad.

    Args:
        model: Object exposing ``parameters()``; each yielded item must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        int: Sum of ``numel()`` over parameters with ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def measure_throughput(model, input_dummy):
    """Time forward passes of ``model`` with CUDA events; return samples/second.

    The (first) dummy tensor is replaced by a random batch of 100 samples
    that keeps its last four dimensions and its device. The model is then
    called 100 times under ``torch.no_grad()``.

    Args:
        model: Callable invoked as ``model(x)``, or as ``model(*inputs)`` when
            ``input_dummy`` is a tuple.
        input_dummy: A 5-D tensor, or a tuple whose first element is a 5-D
            tensor. Only the shape and device of that tensor are used; any
            further tuple elements are passed to the model unchanged.

    Returns:
        float: ``(100 * 100)`` divided by the total elapsed time in seconds.
    """
    batch_size = 100
    num_runs = 100

    takes_tuple = isinstance(input_dummy, tuple)
    template = input_dummy[0] if takes_tuple else input_dummy
    _, T, C, H, W = template.shape
    random_batch = torch.rand(batch_size, T, C, H, W).to(template.device)
    if takes_tuple:
        # Swap in the random batch, keep the remaining positional inputs.
        input_dummy = (random_batch,) + tuple(input_dummy[1:])
    else:
        input_dummy = random_batch

    elapsed_seconds = 0
    with torch.no_grad():
        for _ in range(num_runs):
            start_evt = torch.cuda.Event(enable_timing=True)
            stop_evt = torch.cuda.Event(enable_timing=True)
            start_evt.record()
            if takes_tuple:
                model(*input_dummy)
            else:
                model(input_dummy)
            stop_evt.record()
            # Wait for the GPU so both events are complete before reading.
            torch.cuda.synchronize()
            # elapsed_time reports milliseconds; accumulate seconds.
            elapsed_seconds += start_evt.elapsed_time(stop_evt) / 1000
    return (num_runs * batch_size) / elapsed_seconds
def load_config(filename:str = None):
    """Load a config file and return its contents.

    Args:
        filename: Path of the config file. It is concatenated into the
            progress message, so it must be a ``str`` even though the
            default is ``None``.

    Returns:
        The ``_cfg_dict`` of the loaded ``Config``, or an empty ``dict`` when
        loading raises ``FileNotFoundError`` or ``IOError``.
    """
    banner = 'loading config from ' + filename + ' ...'
    print(banner)
    try:
        loaded = Config(filename=filename)._cfg_dict
    except (FileNotFoundError, IOError):
        # Fall back to an empty config, but tell the user.
        print('warning: fail to load the config!')
        return dict()
    return loaded
def update_config(args, config, exclude_keys=list()):
    """Merge ``config`` into ``args`` in place and return ``args``.

    For each key of ``config``: a truthy value already in ``args`` that
    differs from the config value is kept (and a notice is printed), unless
    the key is listed in ``exclude_keys``. In every other case the config
    value is written into ``args``.

    Args:
        args (dict): Dictionary that is updated in place.
        config (dict): Values to merge in.
        exclude_keys: Keys for which the ``config`` value always wins.

    Returns:
        dict: The same ``args`` object, after the update.
    """
    assert isinstance(args, dict) and isinstance(config, dict)
    for key, cfg_value in config.items():
        current = args.get(key, False)
        if current and current != cfg_value and key not in exclude_keys:
            # A conflicting value set in args takes precedence.
            print(f'overwrite config key -- {key}: {cfg_value} -> {current}')
            continue
        args[key] = cfg_value
    return args
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights, e.g. living on GPU.

    Returns:
        OrderedDict: A new dict with the same keys in the same order, each
        value replaced by ``value.cpu()``. The ``_metadata`` attribute of the
        input is carried over (an empty ``OrderedDict`` if it has none).
    """
    cpu_weights = OrderedDict(
        (name, tensor.cpu()) for name, tensor in state_dict.items()
    )
    # Keep metadata in state_dict
    source_metadata = getattr(state_dict, '_metadata', OrderedDict())
    cpu_weights._metadata = source_metadata  # type: ignore
    return cpu_weights
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    """Initialise the distributed environment for the given launcher.

    If no multiprocessing start method has been chosen yet, it is set to
    ``'spawn'`` before the launcher is inspected.

    Args:
        launcher (str): Either ``'pytorch'`` or ``'mpi'``.
        backend (str): Backend forwarded to the launcher-specific
            initialiser. Defaults to ``'nccl'``.
        **kwargs: Extra keyword arguments forwarded to the launcher-specific
            initialiser.

    Raises:
        ValueError: If ``launcher`` is neither ``'pytorch'`` nor ``'mpi'``.
    """
    current_method = mp.get_start_method(allow_none=True)
    if current_method is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend: str, **kwargs) -> None:
    """Select a CUDA device from ``RANK`` and initialise the process group.

    Args:
        backend (str): Backend passed to ``dist.init_process_group``.
        **kwargs: Extra keyword arguments for ``dist.init_process_group``.

    Raises:
        KeyError: If the ``RANK`` environment variable is not set.
    """
    # TODO: use local_rank instead of rank % num_gpus
    global_rank = int(os.environ['RANK'])
    device_index = global_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend: str, **kwargs) -> None:
    """Initialise the process group from Open MPI environment variables.

    The CUDA device is selected from ``OMPI_COMM_WORLD_LOCAL_RANK``;
    ``WORLD_SIZE`` and ``RANK`` are copied from their ``OMPI_COMM_WORLD_*``
    counterparts before ``dist.init_process_group`` is called.

    Args:
        backend (str): Backend passed to ``dist.init_process_group``.
        **kwargs: Extra keyword arguments for ``dist.init_process_group``.

    Raises:
        KeyError: If ``MASTER_ADDR`` (or one of the required
            ``OMPI_COMM_WORLD_*`` variables) is missing from the environment.
    """
    env = os.environ
    torch.cuda.set_device(int(env['OMPI_COMM_WORLD_LOCAL_RANK']))
    # 29500 is torch.distributed default port
    env.setdefault('MASTER_PORT', '29500')
    if 'MASTER_ADDR' not in env:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    env['WORLD_SIZE'] = env['OMPI_COMM_WORLD_SIZE']
    env['RANK'] = env['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)
def get_dist_info() -> Tuple[int, int]:
    """Return the rank and world size of the current process.

    Returns:
        Tuple[int, int]: ``(rank, world_size)`` from ``torch.distributed``
        when it is available and initialised, otherwise ``(0, 1)``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Not running distributed: behave like a single process.
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
|