| | import torch |
| |
|
| |
|
def p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Post one non-blocking send and one non-blocking receive for KV / dKV.

    Used for the point-to-point exchange of KV and dKV in attention with
    context parallelism.

    Args:
        rank: Rank of the caller; only its parity is used, to pick the order
            in which the two operations are posted.
        send_tensor: Tensor handed to ``torch.distributed.isend``.
        send_dst: Destination rank for the send.
        recv_tensor: Tensor handed to ``torch.distributed.irecv``.
        recv_src: Source rank for the receive.
        cp_group: Process group passed to both operations.
        batch_p2p_comm: If truthy, wrap both operations in
            ``torch.distributed.P2POp`` and launch them together through
            ``torch.distributed.batch_isend_irecv``; otherwise call
            ``isend`` / ``irecv`` directly.

    Returns:
        The list returned by ``batch_isend_irecv`` in batched mode, otherwise
        a two-element list holding the results of the ``isend`` / ``irecv``
        calls in the order they were posted.
    """
    dist = torch.distributed
    outgoing = (dist.isend, send_tensor, send_dst, cp_group)
    incoming = (dist.irecv, recv_tensor, recv_src, cp_group)

    # Even ranks post send-then-recv, odd ranks recv-then-send.
    # NOTE(review): presumably this pairs each send with a peer that is
    # already receiving, to avoid both sides blocking -- confirm.
    if rank % 2 == 0:
        ordered = (outgoing, incoming)
    else:
        ordered = (incoming, outgoing)

    if batch_p2p_comm:
        batched_ops = [dist.P2POp(*spec) for spec in ordered]
        return dist.batch_isend_irecv(batched_ops)

    # Non-batched path: issue each call immediately, in the chosen order.
    return [comm_fn(*comm_args) for comm_fn, *comm_args in ordered]
| |
|