vidfom's picture
Upload folder using huggingface_hub
31112ad verified
import torch
def p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Exchange KV / dKV tensors between context-parallel ranks via P2P ops.

    One send (``send_tensor`` -> ``send_dst``) and one receive
    (``recv_src`` -> ``recv_tensor``) are posted on ``cp_group``. The posting
    order depends on the parity of ``rank``: even ranks post the send first,
    odd ranks post the receive first.

    Args:
        rank: Rank of the caller; only its parity is used here.
        send_tensor: Tensor handed to the send operation.
        send_dst: Peer rank the send is addressed to.
        recv_tensor: Tensor handed to the receive operation.
        recv_src: Peer rank the receive is expected from.
        cp_group: Process group passed to every communication call.
        batch_p2p_comm: When truthy, both operations are wrapped in
            ``torch.distributed.P2POp`` and submitted together through
            ``torch.distributed.batch_isend_irecv``; otherwise ``isend`` /
            ``irecv`` are invoked directly.
            NOTE(review): an earlier inline comment suggests callers derive
            this flag from ``NVTE_BATCH_MHA_P2P_COMM`` or ``cp_size == 2`` --
            confirm against the caller.

    Returns:
        The value returned by ``batch_isend_irecv`` in batched mode, or a list
        holding the two results of ``isend`` / ``irecv`` in posting order.
        The caller is expected to wait on these before using ``recv_tensor``.
    """
    dist = torch.distributed
    outgoing = (dist.isend, send_tensor, send_dst)
    incoming = (dist.irecv, recv_tensor, recv_src)

    # Posting order alternates with rank parity -- presumably so neighbouring
    # ranks pair a send with a matching receive (TODO confirm intent).
    if rank % 2 == 0:
        ordered_steps = (outgoing, incoming)
    else:
        ordered_steps = (incoming, outgoing)

    if batch_p2p_comm:
        batched_ops = [
            dist.P2POp(comm_fn, tensor, peer, cp_group)
            for comm_fn, tensor, peer in ordered_steps
        ]
        return dist.batch_isend_irecv(batched_ops)

    # Non-batched path: launch each operation right away, keeping the handles.
    return [comm_fn(tensor, peer, cp_group) for comm_fn, tensor, peer in ordered_steps]