import os
import re
import struct
import logging
import math
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml

def get_logger(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
        date_format="%Y-%m-%d %H:%M:%S",
        file=False):
    """
    Get a Python logger instance that logs to the console, or, when
    `file=True`, to a file whose path is the logger name.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Avoid stacking duplicate handlers when the same logger is
    # requested more than once.
    if logger.handlers:
        return logger

    handler = logging.FileHandler(name) if file else logging.StreamHandler()
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger

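# Usage sketch (illustrative only; "exp/train.log" is a hypothetical path):
#
#   logger = get_logger(__name__)
#   logger.info("to the console")
#
#   file_logger = get_logger("exp/train.log", file=True)
#   file_logger.info("to exp/train.log")
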
def get_logger_2(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
        date_format="%Y-%m-%d %H:%M:%S"):
    """
    Get a Python logger instance that logs both to the console and to a
    file whose path is the logger name.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Avoid stacking duplicate handlers when the same logger is
    # requested more than once.
    if logger.handlers:
        return logger

    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(name)
    c_handler.setLevel(logging.INFO)
    f_handler.setLevel(logging.INFO)

    c_format = logging.Formatter(fmt=format_str, datefmt=date_format)
    f_format = logging.Formatter(fmt=format_str, datefmt=date_format)
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger

def getoneNode():
    """
    Return the hostname of the first node in SLURM_JOB_NODELIST,
    assuming node names of the form <part>-<part>-<index>.
    """
    nodelist = os.environ['SLURM_JOB_NODELIST']
    nodelist = nodelist.strip().split(',')[0]
    # Raw string: '[' and ']' must be escaped inside the character class.
    text = re.split(r'[-\[\]]', nodelist)
    if '' in text:
        text.remove('')
    return text[0] + '-' + text[1] + '-' + text[2]

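# Example (illustrative): with SLURM_JOB_NODELIST="gpu-node-[21-24],gpu-node-30",
# the first entry "gpu-node-[21-24]" is parsed and "gpu-node-21" is returned,
# which can then serve as the rendezvous host for dist_init below.
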
def dist_init(host_addr, rank, local_rank, world_size, port=23456):
    host_addr_full = 'tcp://' + host_addr + ':' + str(port)
    dist.init_process_group("nccl", init_method=host_addr_full,
                            rank=rank, world_size=world_size)
    # Bind this process to its own GPU; NCCL expects one distinct
    # device per process.
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(local_rank % num_gpus)
    assert dist.is_initialized()

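# Usage sketch (illustrative; assumes a SLURM launch with one process per
# task, so ranks can be read from SLURM's environment):
#
#   host_addr = getoneNode()
#   rank = int(os.environ['SLURM_PROCID'])
#   local_rank = int(os.environ['SLURM_LOCALID'])
#   world_size = int(os.environ['SLURM_NTASKS'])
#   dist_init(host_addr, rank, local_rank, world_size)
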
def cleanup():
    dist.destroy_process_group()

def average_gradients(model, world_size):
    """All-reduce the gradients of `model` and average them over workers."""
    size = float(world_size)
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

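# Usage sketch (illustrative only): manual gradient averaging in a training
# step, as a hand-rolled alternative to DistributedDataParallel; `model`,
# `optimizer` and `loss` are hypothetical names.
#
#   optimizer.zero_grad()
#   loss.backward()
#   average_gradients(model, world_size)
#   optimizer.step()
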
def data_reduce(data):
    """All-reduce `data` in place and return its average across workers."""
    dist.all_reduce(data, op=dist.ReduceOp.SUM)
    return data / dist.get_world_size()

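# Example (illustrative): average a scalar metric so every rank logs the
# same value; `loss` is a hypothetical per-rank loss tensor.
#
#   mean_loss = data_reduce(loss.detach().clone())
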
def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0):
    """Apply the exponentially decayed learning rate (scaled by `coeff`)."""
    current_lr = coeff * get_reduce_lr(initial_lr, final_lr, current_iter, max_iter)
    set_lr(optimizer, current_lr)

def get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
    """Exponentially interpolate between initial_lr and final_lr."""
    current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
    return current_lr

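# The schedule above is a geometric interpolation:
#
#   lr(t) = initial_lr * (final_lr / initial_lr) ** (t / max_iter)
#
# For example, with initial_lr=1e-3 and final_lr=1e-5:
#   lr(0) = 1e-3, lr(max_iter / 2) = 1e-4, lr(max_iter) = 1e-5.
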
def set_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def parse_config_or_kwargs(config_file, **kwargs):
    """Parse a YAML config file; keyword arguments override file values."""
    with open(config_file) as con_read:
        yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
    # kwargs take precedence over keys read from the file.
    return dict(yaml_config, **kwargs)

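# Usage sketch (illustrative; "config.yaml" and its keys are hypothetical).
# Given a file containing "batch_size: 64" and "lr: 0.001":
#
#   config = parse_config_or_kwargs('config.yaml', lr=0.01)
#   # -> {'batch_size': 64, 'lr': 0.01}
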
def store_yaml(config_file, store_path, **kwargs):
    """
    Copy `config_file` to `store_path`, overwriting the values of any
    top-level keys passed as keyword arguments. Rewriting line by line
    keeps the original comments and key order intact.
    """
    with open(config_file, 'r') as f:
        config_lines = f.readlines()

    keys_list = list(kwargs.keys())
    with open(store_path, 'w') as f:
        for line in config_lines:
            if ':' in line and line.split(':')[0] in keys_list:
                key = line.split(':')[0]
                line = '{}: {}\n'.format(key, kwargs[key])
            f.write(line)

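# Usage sketch (illustrative paths): snapshot the config actually used for
# a run, with the overridden learning rate baked in.
#
#   store_yaml('config.yaml', 'exp/run1/config.yaml', lr=0.01)
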
def check_dir(dir):
    # makedirs also creates missing parents; exist_ok avoids a race when
    # several workers create the directory at the same time.
    os.makedirs(dir, exist_ok=True)

def set_seed(seed=66):
    """Seed numpy, random and torch (including all GPUs) for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

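# Note: seeding alone does not make cuDNN kernels deterministic. When exact
# run-to-run reproducibility is required, the usual additions are (at some
# speed cost):
#
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False
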
def correct_key(state_dict):
    """
    Strip the 'module.' prefix that (Distributed)DataParallel adds to
    parameter names, so the checkpoint loads into an unwrapped model.
    """
    keys = list(state_dict.keys())
    # Only strip when the first key really starts with the 'module.' prefix;
    # a plain substring test would also match names that merely contain
    # the word "module".
    if not keys[0].startswith('module.'):
        return state_dict
    new_state_dict = {}
    for key in keys:
        new_key = '.'.join(key.split('.')[1:])
        new_state_dict[new_key] = state_dict[key]
    return new_state_dict

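# Example (illustrative; 'ckpt.pt' is a hypothetical checkpoint file):
#
#   {'module.encoder.weight': w}  ->  {'encoder.weight': w}
#
#   state_dict = correct_key(torch.load('ckpt.pt', map_location='cpu'))
#   model.load_state_dict(state_dict)
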
def validate_path(dir_name):
    """
    Create the parent directory of `dir_name` if it does not exist,
    e.g. validate_path('exp/model/best.pt') creates 'exp/model'.
    :return: None
    """
    dir_name = os.path.dirname(dir_name)
    if dir_name != '':
        os.makedirs(dir_name, exist_ok=True)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    for param_group in optimizer.param_groups:
        return param_group['lr']
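
if __name__ == "__main__":
    # Minimal smoke test for the non-distributed helpers (illustrative only;
    # the paths and tensors below are hypothetical).
    set_seed(66)
    logger = get_logger(__name__)
    logger.info("seeded everything with 66")
    validate_path('exp/demo/checkpoint.pt')
    fixed = correct_key({'module.layer.weight': torch.zeros(1)})
    logger.info("corrected keys: %s", list(fixed.keys()))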