| import importlib |
| import torch |
| import torch.distributed as dist |
| from .avg_meter import AverageMeter |
| from collections import defaultdict, OrderedDict |
| import os |
| import socket |
| from mmcv.utils import collect_env as collect_base_env |
| try: |
| from mmcv.utils import get_git_hash |
| except: |
| from mmengine.utils import get_git_hash |
| |
| |
| import time |
| import datetime |
| import logging |
|
|
|
|
| def main_process() -> bool: |
| return get_rank() == 0 |
| |
| |
|
|
| def get_world_size() -> int: |
| if not dist.is_available(): |
| return 1 |
| if not dist.is_initialized(): |
| return 1 |
| return dist.get_world_size() |
|
|
| def get_rank() -> int: |
| if not dist.is_available(): |
| return 0 |
| if not dist.is_initialized(): |
| return 0 |
| return dist.get_rank() |
|
|
| def _find_free_port(): |
| |
| sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| |
| sock.bind(('', 0)) |
| port = sock.getsockname()[1] |
| sock.close() |
| |
| return port |
|
|
| def _is_free_port(port): |
| ips = socket.gethostbyname_ex(socket.gethostname())[-1] |
| ips.append('localhost') |
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| return all(s.connect_ex((ip, port)) != 0 for ip in ips) |
|
|
|
|
| |
| |
| |
| |
|
|
| |
|
|
| def init_env(launcher, cfg): |
| """Initialize distributed training environment. |
| If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system |
| environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system |
| environment variable, then a default port ``29500`` will be used. |
| """ |
| if launcher == 'slurm': |
| _init_dist_slurm(cfg) |
| elif launcher == 'ror': |
| _init_dist_ror(cfg) |
| elif launcher == 'None': |
| _init_none_dist(cfg) |
| else: |
| raise RuntimeError(f'{cfg.launcher} has not been supported!') |
|
|
| def _init_none_dist(cfg): |
| cfg.dist_params.num_gpus_per_node = 1 |
| cfg.dist_params.world_size = 1 |
| cfg.dist_params.nnodes = 1 |
| cfg.dist_params.node_rank = 0 |
| cfg.dist_params.global_rank = 0 |
| cfg.dist_params.local_rank = 0 |
| os.environ["WORLD_SIZE"] = str(1) |
|
|
| def _init_dist_ror(cfg): |
| from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size |
| cfg.dist_params.num_gpus_per_node = get_local_size() |
| cfg.dist_params.world_size = get_world_size() |
| cfg.dist_params.nnodes = (get_world_size()) // (get_local_size()) |
| cfg.dist_params.node_rank = get_node_rank() |
| cfg.dist_params.global_rank = get_world_rank() |
| cfg.dist_params.local_rank = get_local_rank() |
| os.environ["WORLD_SIZE"] = str(get_world_size()) |
|
|
|
|
| def _init_dist_slurm(cfg): |
| if 'NNODES' not in os.environ: |
| os.environ['NNODES'] = str(cfg.dist_params.nnodes) |
| if 'NODE_RANK' not in os.environ: |
| os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank) |
|
|
| |
| num_gpus = torch.cuda.device_count() |
| world_size = int(os.environ['NNODES']) * num_gpus |
| os.environ['WORLD_SIZE'] = str(world_size) |
|
|
| |
| if 'MASTER_PORT' in os.environ: |
| master_port = str(os.environ['MASTER_PORT']) |
| else: |
| |
| |
| if _is_free_port(16500): |
| master_port = '16500' |
| else: |
| master_port = str(_find_free_port()) |
| os.environ['MASTER_PORT'] = master_port |
|
|
| |
| if 'MASTER_ADDR' in os.environ: |
| master_addr = str(os.environ['MASTER_PORT']) |
| |
| |
| else: |
| master_addr = '127.0.0.1' |
| os.environ['MASTER_ADDR'] = master_addr |
|
|
| |
| cfg.dist_params.dist_url = 'env://' |
| |
| cfg.dist_params.num_gpus_per_node = num_gpus |
| cfg.dist_params.world_size = world_size |
| cfg.dist_params.nnodes = int(os.environ['NNODES']) |
| cfg.dist_params.node_rank = int(os.environ['NODE_RANK']) |
| |
| |
| |
| |
|
|
| def get_func(func_name): |
| """ |
| Helper to return a function object by name. func_name must identify |
| a function in this module or the path to a function relative to the base |
| module. |
| @ func_name: function name. |
| """ |
| if func_name == '': |
| return None |
| try: |
| parts = func_name.split('.') |
| |
| if len(parts) == 1: |
| return globals()[parts[0]] |
| |
| module_name = '.'.join(parts[:-1]) |
| module = importlib.import_module(module_name) |
| return getattr(module, parts[-1]) |
| except: |
| raise RuntimeError(f'Failed to find function: {func_name}') |
|
|
| class Timer(object): |
| """A simple timer.""" |
|
|
| def __init__(self): |
| self.reset() |
|
|
| def tic(self): |
| |
| |
| self.start_time = time.time() |
|
|
| def toc(self, average=True): |
| self.diff = time.time() - self.start_time |
| self.total_time += self.diff |
| self.calls += 1 |
| self.average_time = self.total_time / self.calls |
| if average: |
| return self.average_time |
| else: |
| return self.diff |
|
|
| def reset(self): |
| self.total_time = 0. |
| self.calls = 0 |
| self.start_time = 0. |
| self.diff = 0. |
| self.average_time = 0. |
|
|
| class TrainingStats(object): |
| """Track vital training statistics.""" |
| def __init__(self, log_period, tensorboard_logger=None): |
| self.log_period = log_period |
| self.tblogger = tensorboard_logger |
| self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time'] |
| self.iter_timer = Timer() |
| |
| self.filter_size = log_period |
| def create_smoothed_value(): |
| return AverageMeter() |
| self.smoothed_losses = defaultdict(create_smoothed_value) |
| |
| |
|
|
|
|
| def IterTic(self): |
| self.iter_timer.tic() |
|
|
| def IterToc(self): |
| return self.iter_timer.toc(average=False) |
|
|
| def reset_iter_time(self): |
| self.iter_timer.reset() |
|
|
| def update_iter_stats(self, losses_dict): |
| """Update tracked iteration statistics.""" |
| for k, v in losses_dict.items(): |
| self.smoothed_losses[k].update(float(v), 1) |
|
|
| def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}): |
| """Log the tracked statistics.""" |
| if (cur_iter % self.log_period == 0): |
| stats = self.get_stats(cur_iter, optimizer, max_iters, val_err) |
| log_stats(stats) |
| if self.tblogger: |
| self.tb_log_stats(stats, cur_iter) |
| for k, v in self.smoothed_losses.items(): |
| v.reset() |
|
|
| def tb_log_stats(self, stats, cur_iter): |
| """Log the tracked statistics to tensorboard""" |
| for k in stats: |
| |
| if k not in self.tb_ignored_keys: |
| v = stats[k] |
| if isinstance(v, dict): |
| self.tb_log_stats(v, cur_iter) |
| else: |
| self.tblogger.add_scalar(k, v, cur_iter) |
|
|
|
|
| def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}): |
| eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter) |
|
|
| eta = str(datetime.timedelta(seconds=int(eta_seconds))) |
| stats = OrderedDict( |
| iter=cur_iter, |
| time=self.iter_timer.average_time, |
| eta=eta, |
| ) |
| optimizer_state_dict = optimizer.state_dict() |
| lr = {} |
| for i in range(len(optimizer_state_dict['param_groups'])): |
| lr_name = 'group%d_lr' % i |
| lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr'] |
|
|
| stats['lr'] = OrderedDict(lr) |
| for k, v in self.smoothed_losses.items(): |
| stats[k] = v.avg |
|
|
| stats['val_err'] = OrderedDict(val_err) |
| stats['max_iters'] = max_iters |
| return stats |
|
|
|
|
| def reduce_dict(input_dict, average=True): |
| """ |
| Reduce the values in the dictionary from all processes so that process with rank |
| 0 has the reduced results. |
| Args: |
| @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. |
| @average (bool): whether to do average or sum |
| Returns: |
| a dict with the same keys as input_dict, after reduction. |
| """ |
| world_size = get_world_size() |
| if world_size < 2: |
| return input_dict |
| with torch.no_grad(): |
| names = [] |
| values = [] |
| |
| for k in sorted(input_dict.keys()): |
| names.append(k) |
| values.append(input_dict[k]) |
| values = torch.stack(values, dim=0) |
| dist.reduce(values, dst=0) |
| if dist.get_rank() == 0 and average: |
| |
| |
| values /= world_size |
| reduced_dict = {k: v for k, v in zip(names, values)} |
| return reduced_dict |
|
|
|
|
| def log_stats(stats): |
| logger = logging.getLogger() |
| """Log training statistics to terminal""" |
| lines = "[Step %d/%d]\n" % ( |
| stats['iter'], stats['max_iters']) |
|
|
| lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % ( |
| stats['total_loss'], stats['time'], stats['eta']) |
|
|
| |
| lines += "\t\t" |
| for k, v in stats.items(): |
| if 'loss' in k.lower() and 'total_loss' not in k.lower(): |
| lines += "%s: %.3f" % (k, v) + ", " |
| lines = lines[:-3] |
| lines += '\n' |
|
|
| |
| lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", " |
| lines += '\n' |
|
|
| |
| lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items()) |
| lines += '\n' |
| logger.info(lines[:-1]) |
|
|
|
|