| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Communications utilities.""" |
|
|
|
|
| import torch |
|
|
| from megatron import mpu |
|
|
|
|
|
|
| |
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Fill `recv_buffer` in place with data from the previous pipeline rank.

    The first pipeline stage has no predecessor, so the call is a no-op
    there. On every other stage `recv_buffer` must be provided; the call
    blocks until the receive request has completed.
    """
    if mpu.is_pipeline_first_stage():
        return

    assert recv_buffer is not None
    prev_rank = mpu.get_pipeline_model_parallel_prev_rank()
    recv_op = torch.distributed.P2POp(
        torch.distributed.irecv, recv_buffer, prev_rank)
    for request in torch.distributed.batch_isend_irecv([recv_op]):
        request.wait()

    # Block the host until all queued CUDA work on the current device is
    # done, so the buffer is safe to use once this function returns.
    torch.cuda.synchronize()
|
|
|
|
|
|
| |
def send_to_next_pipeline_rank(tensor=None):
    """Send `tensor` to the next pipeline rank.

    The last pipeline stage has no successor, so the call is a no-op
    there. On every other stage `tensor` must be provided; the call
    blocks until the send request has completed.
    """
    if mpu.is_pipeline_last_stage():
        return

    assert tensor is not None
    next_rank = mpu.get_pipeline_model_parallel_next_rank()
    send_op = torch.distributed.P2POp(
        torch.distributed.isend, tensor, next_rank)
    for request in torch.distributed.batch_isend_irecv([send_op]):
        request.wait()

    # Block the host until all queued CUDA work on the current device is
    # done before handing control back to the caller.
    torch.cuda.synchronize()
|
|
|
|
|
|
def _is_cuda(tensor):
    """Assert that `tensor` is not None and reports `is_cuda` as true.

    Returns nothing; raises AssertionError when either condition fails.
    """
    # Short-circuit keeps the attribute access from running on None.
    assert tensor is not None and tensor.is_cuda
|
|
|
|
|
|
def _is_cuda_contiguous(tensor):
    """Assert that `tensor` passes `_is_cuda` and is contiguous.

    Returns nothing; raises AssertionError when a check fails.
    """
    # Run the None / CUDA checks first so `is_contiguous` is never
    # called on None.
    _is_cuda(tensor)
    contiguous = tensor.is_contiguous()
    assert contiguous
|
|
|
|
|
|
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the last pipeline stage to the pipeline group.

    Arguments:
        size: shape of the receive buffer allocated on non-last stages.
        dtype: dtype of the receive buffer allocated on non-last stages.
        tensor: the data to broadcast; only used on the last stage, where
            it must be a contiguous CUDA tensor.

    Returns:
        The broadcast tensor. When this rank is both the first and the
        last stage, `tensor` is returned unchanged and no communication
        takes place.
    """
    on_last_stage = mpu.is_pipeline_last_stage()
    on_first_stage = mpu.is_pipeline_first_stage()

    # This rank is both first and last stage: nothing to broadcast.
    if on_first_stage and on_last_stage:
        return tensor

    if not on_last_stage:
        # Receivers allocate a fresh buffer on the current CUDA device.
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    else:
        _is_cuda_contiguous(tensor)

    source_rank = mpu.get_pipeline_model_parallel_last_rank()
    pipeline_group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, source_rank, pipeline_group)

    return tensor
|
|
|
|
|
|
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the last pipeline stage to the first one.

    The broadcast runs over the group returned by
    `mpu.get_embedding_group()`, and only the first and last stages
    take part in it.

    Arguments:
        size: shape of the receive buffer allocated on the first stage.
        dtype: dtype of the receive buffer allocated on the first stage.
        tensor: the data to broadcast; only used on the last stage, where
            it must be a contiguous CUDA tensor.

    Returns:
        The broadcast tensor on the first and last stages, `None` on every
        other stage. When this rank is both first and last stage, `tensor`
        is returned unchanged.
    """
    on_last_stage = mpu.is_pipeline_last_stage()
    on_first_stage = mpu.is_pipeline_first_stage()

    # This rank is both first and last stage: nothing to broadcast.
    if on_first_stage and on_last_stage:
        return tensor

    # Stages that are neither first nor last do not participate.
    if not (on_last_stage or on_first_stage):
        return None

    if on_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        # First stage: allocate the buffer that receives the data.
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    source_rank = mpu.get_pipeline_model_parallel_last_rank()
    embedding_group = mpu.get_embedding_group()
    torch.distributed.broadcast(tensor, source_rank, embedding_group)

    return tensor
|
|
|
|
|
|
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last pipeline stage into the first one.

    `tensor` is updated in place on the first stage. The broadcast runs
    over the group returned by `mpu.get_embedding_group()`; stages that
    are neither first nor last do nothing, and so does a rank that is
    both first and last stage.

    Arguments:
        size: shape of the staging buffer used on the first stage when
            `tensor` is not contiguous.
        dtype: dtype of that staging buffer.
        tensor: CUDA tensor holding the data (last stage) or receiving it
            (first stage).
    """
    on_last_stage = mpu.is_pipeline_last_stage()
    on_first_stage = mpu.is_pipeline_first_stage()

    # This rank is both first and last stage: nothing to copy.
    if on_first_stage and on_last_stage:
        return

    # Stages that are neither first nor last do not participate.
    if not (on_last_stage or on_first_stage):
        return

    _is_cuda(tensor)
    contiguous = tensor.is_contiguous()
    source_rank = mpu.get_pipeline_model_parallel_last_rank()
    embedding_group = mpu.get_embedding_group()

    # Pick the buffer that actually goes through the broadcast.
    if contiguous:
        buffer = tensor
    elif on_last_stage:
        buffer = tensor.contiguous()
    else:
        buffer = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(buffer, source_rank, embedding_group)

    # A non-contiguous destination received into a staging buffer, so
    # the values still have to be written back into `tensor`.
    if on_first_stage and not contiguous:
        tensor[...] = buffer
|
|
|
|
|
|
def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """Broadcast a tensor from `rank` to all other ranks.

    Every rank passes the size and dtype; only the source rank needs to
    pass the tensor itself.

    Arguments:
        size: shape of the receive buffer allocated on non-source ranks.
        dtype: dtype of the receive buffer allocated on non-source ranks.
        tensor: the data to broadcast; only used on the source rank,
            where it must be a contiguous CUDA tensor.
        rank: the rank that owns the data.

    Returns:
        The broadcast tensor.
    """
    if torch.distributed.get_rank() != rank:
        # Receivers allocate a fresh buffer on the current CUDA device.
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    else:
        _is_cuda_contiguous(tensor)

    torch.distributed.broadcast(tensor, rank)

    return tensor
|
|
|
|
|
|
def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a list of values of the given dtype from `rank`.

    On the source rank `list_values` is turned into a tensor on the
    current CUDA device; the other ranks pass no tensor. The actual
    broadcast is delegated to `broadcast_tensor`, whose result is
    returned.
    """
    is_source = torch.distributed.get_rank() == rank
    payload = (
        torch.tensor(list_values, dtype=dtype,
                     device=torch.cuda.current_device())
        if is_source
        else None
    )

    return broadcast_tensor(size, dtype, tensor=payload, rank=rank)
|
|
|
|
|
|
def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of integer values as a `torch.int64` tensor.

    Thin wrapper around `broadcast_list`; returns its result.
    """
    int_dtype = torch.int64
    return broadcast_list(size, int_dtype, list_values=int_list, rank=rank)
|
|
|
|
|
|
def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values as a `torch.float32` tensor.

    Thin wrapper around `broadcast_list`; returns its result.
    """
    float_dtype = torch.float32
    return broadcast_list(size, float_dtype, list_values=float_list, rank=rank)
|
|