"""Distributed-training helpers built on torch.distributed: process-group initialization,
rank/device queries, wrappers around common collectives, and rank-aware printing/log backup."""
import datetime
import functools
import os
import sys
from typing import List, Union

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

# Module-level distributed state; the defaults describe a single-process, non-distributed run.
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__rank_str_zfill = '0'
__initialized = False

def initialized():
    return __initialized

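# Initialization helper. If CUDA or the "RANK" environment variable is absent, it falls back to
# single-process mode (the module-level defaults stay in place and all collective wrappers become
# no-ops). Otherwise it derives the local rank as RANK % num_gpus, pins the CUDA device, and
# initializes the process group with the given backend and timeout.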
def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return

    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))

    global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')

def get_rank():
    return __rank


def get_rank_str_zfill():
    return __rank_str_zfill


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def set_gpu_id(gpu_id: int):
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0

def new_group(ranks: List[int]):
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None


def new_local_machine_group():
    if __initialized:
        # torch.distributed.new_subgroups() with default arguments yields one subgroup per machine
        # (assuming one process per GPU); return the subgroup this rank belongs to.
        cur_subgroup, subgroups = tdist.new_subgroups()
        return cur_subgroup
    return None

def barrier():
    if __initialized:
        tdist.barrier()

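# The collective wrappers below are no-ops until __initialize() has run. CPU tensors are staged
# through a temporary CUDA copy because the default NCCL backend operates on GPU tensors; note
# that on this CPU path the result is copied back immediately, so it is only meaningful with the
# default synchronous mode (async_op=False).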
def allreduce(t: torch.Tensor, async_op=False):
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None

def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls

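# Variant of allgather for tensors whose first (batch) dimension differs across ranks: every rank
# first all-gathers its size, pads its tensor to the largest batch size, all-gathers the padded
# tensors, then trims each gathered result back to its true length.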
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls

def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)

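# Collects one scalar per rank (via a world-size vector with a single non-zero slot that is
# all-reduced) and returns either the raw tensor or a list of formatted strings, one per rank.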
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]

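# Decorators that restrict execution to the (local) master rank: non-master ranks return None, and
# a barrier keeps all ranks in step afterwards. Passing force=True to the wrapped call runs the
# function on the calling rank regardless.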
def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper

def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper

def for_visualize(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper

def finalize():
    if __initialized:
        tdist.destroy_process_group()

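# Main entry point: initializes the process group, patches print() so only the local master emits
# timestamped output, and (optionally) mirrors stdout/stderr of the chosen master rank(s) to
# backup files under local_out_path.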
def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
    try:
        __initialize(fork=False, timeout_minutes=timeout_minutes)
        barrier()
    except RuntimeError as e:
        print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
        raise e

    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
    _change_builtin_print(is_local_master())
    if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)

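# Monkey-patches builtins.print: non-master ranks are silenced (unless force=True is passed), and
# master output is prefixed with a timestamp plus the caller's file name and line number. Extra
# keyword arguments understood by the patched print: force, clean (skip the prefix), and deeper
# (report the caller's caller).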
def _change_builtin_print(is_master):
    import builtins as __builtin__

    builtin_print = __builtin__.print
    if type(builtin_print) != type(open):
        # print has already been replaced (it is no longer a builtin function), so do nothing
        return

    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
                builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)

    __builtin__.print = prt

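# Tee-style stream: everything written goes both to the original terminal stream and to a backup
# text file inside the local output directory, with a RESTART banner appended when the file
# already exists. close() restores the original sys.stdout / sys.stderr.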
class BackupStreamToFile(object):
    def __init__(self, local_output_dir, for_stdout=True):
        self.for_stdout = for_stdout
        self.terminal_stream = sys.stdout if for_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
        existing = os.path.exists(fname)
        self.file_stream = open(fname, 'a')
        if existing:
            time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
        self.file_stream.flush()
        self.enabled = True

    def write(self, message):
        self.terminal_stream.write(message)
        self.file_stream.write(message)

    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()

    def close(self):
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.for_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()

    def __del__(self):
        self.close()
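

# --- Minimal usage sketch (illustrative, not part of the library) ---
# Assumes a launch such as `torchrun --nproc_per_node=<N> <your_script>.py`, which sets the RANK
# environment variable; the output directory name below is purely illustrative.
if __name__ == '__main__':
    init_distributed_mode(local_out_path='./local_output', timeout_minutes=30)
    print(f'world size = {get_world_size()}, rank = {get_rank()}, device = {get_device()}')
    x = torch.ones(1, device=get_device()) * get_rank()
    allreduce(x)  # in a distributed run, every rank now holds 0 + 1 + ... + (N-1)
    print(f'allreduce result: {x.item()}')
    barrier()
    finalize()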