# Copyright (c) CAIRI AI Lab. All rights reserved
import os
import random
import logging
from collections import OrderedDict
from typing import Optional, Tuple

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
from torch import distributed as dist

from .config_utils import Config

def set_seed(seed):
    """Seed all random number generators for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Force deterministic cuDNN kernels (may slow down training)
    cudnn.deterministic = True

def print_log(message):
    """Print a message to stdout and mirror it to the logging module."""
    print(message)
    logging.info(message)

def output_namespace(namespace):
    """Format an argparse.Namespace as a readable multi-line string."""
    configs = namespace.__dict__
    message = ''
    for k, v in configs.items():
        message += '\n' + k + ': \t' + str(v) + '\t'
    return message

def check_dir(path):
    """Create `path` if it does not exist; return whether it already existed."""
    if not os.path.exists(path):
        os.makedirs(path)
        return False
    return True

def get_dataset(config):
    """Build a dataset by passing the config dict as kwargs to load_data."""
    from src.datasets import load_data
    return load_data(**config)

def count_parameters(model):
    """Count the trainable parameters of a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def measure_throughput(model, input_dummy):
    """Measure model throughput in samples/s; requires CUDA tensors and device."""
    bs = 100
    repetitions = 100
    if isinstance(input_dummy, tuple):
        # Replace the first tensor with a batch of `bs` random samples
        input_dummy = list(input_dummy)
        _, T, C, H, W = input_dummy[0].shape
        _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device)
        input_dummy[0] = _input
        input_dummy = tuple(input_dummy)
    else:
        _, T, C, H, W = input_dummy.shape
        input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device)
    total_time = 0
    with torch.no_grad():
        for _ in range(repetitions):
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            starter.record()
            if isinstance(input_dummy, tuple):
                _ = model(*input_dummy)
            else:
                _ = model(input_dummy)
            ender.record()
            # Wait for the GPU to finish before reading the elapsed time
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender) / 1000  # ms -> s
            total_time += curr_time
    throughput = (repetitions * bs) / total_time
    return throughput
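
# A minimal usage sketch for `measure_throughput`, kept commented out so the
# module stays import-safe. `MyModel` and the (1, 10, 1, 64, 64) shape are
# hypothetical; a CUDA device is required since timing uses torch.cuda.Event.
#
#   model = MyModel().cuda().eval()
#   dummy = torch.rand(1, 10, 1, 64, 64).cuda()  # (B, T, C, H, W)
#   print(f'Throughput: {measure_throughput(model, dummy):.1f} samples/s')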

def load_config(filename: Optional[str] = None):
    """Load a config file and return its contents as a dict."""
    print('loading config from ' + filename + ' ...')
    try:
        configfile = Config(filename=filename)
        config = configfile._cfg_dict
    except (FileNotFoundError, IOError):
        config = dict()
        print('warning: failed to load the config!')
    return config
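
# Illustrative call to `load_config` (the config path is hypothetical):
#
#   config = load_config('configs/mmnist/SimVP.py')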

def update_config(args, config, exclude_keys=list()):
    """Update the args dict with a new config; existing args values win."""
    assert isinstance(args, dict) and isinstance(config, dict)
    for k in config.keys():
        if args.get(k, False):
            # Keep the value from args and report the overwritten config key
            if args[k] != config[k] and k not in exclude_keys:
                print(f'overwrite config key -- {k}: {config[k]} -> {args[k]}')
            else:
                args[k] = config[k]
        else:
            args[k] = config[k]
    return args
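
# A small illustration of `update_config` precedence (keys are hypothetical):
# a truthy value already present in `args` wins over `config`, while missing
# or falsy keys are filled in from `config`.
#
#   args = {'lr': 0.01, 'batch_size': 0}
#   cfg = {'lr': 0.001, 'batch_size': 16, 'epochs': 50}
#   update_config(args, cfg)
#   # prints: overwrite config key -- lr: 0.001 -> 0.01
#   # args == {'lr': 0.01, 'batch_size': 16, 'epochs': 50}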

def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(  # type: ignore
        state_dict, '_metadata', OrderedDict())
    return state_dict_cpu
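
# Typical checkpointing use of `weights_to_cpu` (the file name is illustrative):
#
#   torch.save(weights_to_cpu(model.state_dict()), 'latest.pth')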

def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
    """Initialize torch.distributed for the given launcher and backend."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')

def _init_dist_pytorch(backend: str, **kwargs) -> None:
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)

def _init_dist_mpi(backend: str, **kwargs) -> None:
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)

def get_dist_info() -> Tuple[int, int]:
    """Return (rank, world_size); falls back to (0, 1) outside distributed runs."""
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
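
# A minimal sketch of the distributed entry points, assuming the script was
# launched with torchrun (which sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT):
#
#   init_dist('pytorch', backend='nccl')
#   rank, world_size = get_dist_info()
#   if rank == 0:
#       print_log(f'running on {world_size} processes')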