import logging
from pathlib import Path

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yaml
from easydict import EasyDict


def cfg_from_yaml_file(cfg_file):
    """Load a YAML config file into an EasyDict and attach the project root."""
    with open(cfg_file, 'r') as f:
        try:
            new_config = yaml.load(f, Loader=yaml.FullLoader)
        except AttributeError:
            # PyYAML < 5.1 has no FullLoader; rewind and fall back to the default loader.
            f.seek(0)
            new_config = yaml.load(f)

    cfg = EasyDict(new_config)
    cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve()
    return cfg


def log_config_to_file(cfg, pre='cfg', logger=None):
    """Recursively log every key of a (possibly nested) EasyDict config."""
    logger = logger if logger is not None else logging.getLogger(__name__)
    for key, val in cfg.items():
        if isinstance(val, EasyDict):
            logger.info('\n%s.%s = edict()' % (pre, key))
            log_config_to_file(val, pre=pre + '.' + key, logger=logger)
            continue
        logger.info('%s.%s: %s' % (pre, key, val))


def init_dist_pytorch(batch_size, local_rank, backend='nccl'):
    """Initialize distributed training and split the global batch size across GPUs.

    Assumes the process-group environment variables (MASTER_ADDR, MASTER_PORT,
    RANK, WORLD_SIZE) are already set, e.g. by torchrun / torch.distributed.launch.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    num_gpus = torch.cuda.device_count()
    assert batch_size % num_gpus == 0, \
        'Batch size (%d) should be divisible by the number of GPUs (%d)' % (batch_size, num_gpus)
    torch.cuda.set_device(local_rank % num_gpus)
    dist.init_process_group(backend=backend)

    batch_size_each_gpu = batch_size // num_gpus
    rank = dist.get_rank()
    return batch_size_each_gpu, rank


def get_dist_info():
    """Return (rank, world_size); (0, 1) when distributed mode is not initialized."""
    if torch.__version__ < '1.0':
        # PyTorch < 1.0 only exposes this through a private attribute.
        initialized = dist._initialized
    elif dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False

    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def create_logger(log_file=None, log_level=logging.INFO):
    """Create a logger with a console handler and an optional file handler."""
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level)
    formatter = logging.Formatter('%(asctime)s %(levelname)5s %(message)s')

    console = logging.StreamHandler()
    console.setLevel(log_level)
    console.setFormatter(formatter)
    logger.addHandler(console)

    if log_file is not None:
        file_handler = logging.FileHandler(filename=log_file)
        file_handler.setLevel(log_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
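

# A minimal usage sketch (an illustration, not part of the module): build a logger,
# load a config, and dump it to the log. The 'cfgs/example.yaml' path below is a
# hypothetical placeholder.
if __name__ == '__main__':
    logger = create_logger()  # console-only; pass log_file=... to also write to disk
    cfg = cfg_from_yaml_file('cfgs/example.yaml')  # hypothetical config path
    log_config_to_file(cfg, logger=logger)
    rank, world_size = get_dist_info()  # returns (0, 1) outside distributed training
    logger.info('rank=%d / world_size=%d' % (rank, world_size))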