|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
from megatron.core import parallel_state as mpu
|
|
|
from megatron.core.packed_seq_params import PackedSeqParams
|
|
|
|
|
|
|
|
|
def preprocess_packed_seqs(input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]:
    """
    Pack a padded batch into one flat THD-format sequence for Megatron.

    The valid tokens of each row (selected by ``attention_mask``) are copied into a
    single zero-initialised buffer. Each row is padded up to a multiple of
    ``align_size``:
      - the TP world size, or
      - ``2 * tp_size * cp_size`` when context parallelism (CP) is active.

    With CP, each padded row is split into ``2 * cp_size`` equal chunks. CP rank ``r``
    keeps chunk ``r`` and chunk ``2 * cp_size - 1 - r`` (GPU0 gets the first and last
    chunks, GPU1 the second and second-to-last, and so on). This balances work under
    causal masking. See https://github.com/NVIDIA/TransformerEngine/issues/1368

    Args:
        input_ids: batched inputs. Dim 0 is the batch and dim 1 the padded sequence;
            any trailing dims are carried through unchanged.
            NOTE(review): assumes (batch, seq, ...) layout — confirm with callers.
        attention_mask: per-row mask, used both for ``sum(dim=-1)`` and as an index
            into ``input_ids[i]``.
            NOTE(review): presumably a bool tensor; an integer mask would be treated
            as index positions — confirm.
        pre_process: when False, ``input_ids`` is returned unchanged and only the
            packing metadata is computed.

    Returns:
        ``(packed, packed_seq_params)``.
          - ``packed`` has shape ``(1, total_padded_len // cp_size, *input_ids.shape[2:])``
            when ``pre_process`` is True; otherwise it is the original ``input_ids``.
          - ``packed_seq_params`` uses the *padded* cumulative lengths for q and kv,
            and reports the padded max length.
    """
    batch_size = input_ids.shape[0]
    device = input_ids.device

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    # With CP every row is cut into 2 * cp_size chunks, each of which must stay TP-aligned.
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size
    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
    max_seqlen_in_batch = seqlens_in_batch_padded.max().item()

    if pre_process:
        # Bring per-row lengths/offsets to the host once, instead of indexing device
        # tensors (one sync each) inside the loop.
        seqlens = seqlens_in_batch.tolist()
        seqlens_padded = seqlens_in_batch_padded.tolist()
        offsets_padded = cu_seqlens_padded.tolist()

        shape = list(input_ids.shape[1:])
        shape[0] = offsets_padded[-1] // cp_size
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=device)

        for i in range(batch_size):
            d = input_ids[i, attention_mask[i]]
            if cp_size <= 1:
                start = offsets_padded[i]
                input_ids_rmpad[start : start + seqlens[i]] = d
                continue

            half_seqlen = (seqlens_padded[i] // cp_size) // 2
            start_idx = offsets_padded[i] // cp_size
            # Padding lives past the valid tokens; both chunks are clipped to it and the
            # remainder of each destination slot stays zero.
            valid_len = d.shape[0]

            # Front chunk: chunk index cp_rank.
            front_start = half_seqlen * cp_rank
            front_end = min(half_seqlen * (cp_rank + 1), valid_len)
            if front_end > front_start:
                input_ids_rmpad[start_idx : start_idx + (front_end - front_start)] = d[front_start:front_end]

            # Back chunk: chunk index 2 * cp_size - 1 - cp_rank, mirrored from the row end.
            back_start = seqlens_padded[i] - half_seqlen * (cp_rank + 1)
            back_end = min(seqlens_padded[i] - half_seqlen * cp_rank, valid_len)
            if back_end > back_start:
                dst = start_idx + half_seqlen
                input_ids_rmpad[dst : dst + (back_end - back_start)] = d[back_start:back_end]

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params
    return input_ids, packed_seq_params
|
|
|
|
|
|
|
|
|
def postprocess_packed_seqs(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    post_process: bool = True,
) -> torch.Tensor:
    """
    Unpack a THD-packed model output into a zero-padded ``(batch_size, seq_len, ...)`` tensor.

    Rows are located through ``packed_seq_params.cu_seqlens_q_padded``. With context
    parallelism (CP), the shards from all CP ranks are all-gathered, and each padded
    row is reassembled: rank ``j``'s first half-chunk goes to chunk ``j`` and its second
    half-chunk to chunk ``2 * cp_size - 1 - j``. Finally, the first ``attention_mask[i].sum()``
    values are scattered to the positions where ``attention_mask[i]`` is set.

    Args:
        output: packed model output, indexed as ``output[0][packed_pos]``.
            NOTE(review): presumably shaped (1, packed_len, ...) — confirm.
        packed_seq_params: packing metadata; only ``cu_seqlens_q_padded`` is read.
        attention_mask: per-row mask used as an index into ``output_new[i]``.
        batch_size: number of rows in the result.
        seq_len: sequence length of the result.
        post_process: when False, ``output`` is returned unchanged.

    Returns:
        A new tensor of shape ``[batch_size, seq_len, *output.shape[2:]]`` with
        ``output``'s dtype and device, or ``output`` itself if ``post_process`` is False.
    """
    if not post_process:
        return output
    shape = [batch_size, seq_len] + list(output.shape[2:])
    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)

    cp_size = mpu.get_context_parallel_world_size()
    if cp_size > 1:
        # Gather every CP rank's shard. The local slot is swapped back to the non-detached
        # tensor so gradients still flow through this rank's own chunks.
        output_list = [torch.empty_like(output) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        output_list[mpu.get_context_parallel_rank()] = output
    else:
        output_list = [output]

    # Host-side copies avoid a device sync for every index/slice inside the loop.
    cu_seqlens_padded = packed_seq_params.cu_seqlens_q_padded.tolist()
    valid_lens = attention_mask.sum(dim=-1).tolist()

    for i in range(batch_size):
        row_start = cu_seqlens_padded[i]
        valid_len = valid_lens[i]
        if cp_size <= 1:
            output_new[i, attention_mask[i]] = output[0][row_start : row_start + valid_len]
            continue

        chunk_len = (cu_seqlens_padded[i + 1] - row_start) // cp_size
        half_seqlen = chunk_len // 2
        padded_len = chunk_len * cp_size
        packed_start = row_start // cp_size
        # Match output's dtype: a default (float32) buffer would silently round-trip values.
        tmp = torch.empty(padded_len, *output.shape[2:], dtype=output.dtype, device=output.device)
        for j, shard in enumerate(output_list):
            o = shard[0]
            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o[packed_start : packed_start + half_seqlen]
            tmp[padded_len - (j + 1) * half_seqlen : padded_len - j * half_seqlen] = o[packed_start + half_seqlen : packed_start + chunk_len]
        output_new[i, attention_mask[i]] = tmp[:valid_len]

    return output_new
|
|
|
|
|
|
|
|
|
def remove_left_padding(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    sequence_parallel: bool = False,
    pre_process: bool = True,
):
    """
    Move each row's masked-in tokens to the front and trim the sequence dimension.

    Dim 1 of the outputs is the longest row's valid-token count. When ``sequence_parallel``
    is set, it is rounded up to a multiple of the tensor-model-parallel world size. Slots
    past a row's valid tokens are zero.

    Returns:
        ``(new_input_ids, new_attention_mask, new_position_ids)``. If ``pre_process`` is
        False, the first element is the original ``input_ids`` (not compacted).

    Raises:
        AssertionError: if ``attention_mask`` or ``position_ids`` is not 2-D, or if the
            context-parallel world size is not 1.
    """
    assert attention_mask.ndim == 2
    assert position_ids.ndim == 2
    assert mpu.get_context_parallel_world_size() == 1, "Context parallel size without seq_pack is not supported"

    n_rows = input_ids.shape[0]
    ids_shape = list(input_ids.shape)
    row_lengths = attention_mask.sum(dim=1)
    target_len = row_lengths.max().item()
    if sequence_parallel:
        # The TP world size doubles as the sequence-parallel world size here.
        tp_world = mpu.get_tensor_model_parallel_world_size()
        target_len += (tp_world - target_len % tp_world) % tp_world
    ids_shape[1] = target_len

    compact_ids = torch.zeros(ids_shape, dtype=input_ids.dtype, device=input_ids.device) if pre_process else None
    compact_mask = torch.zeros((n_rows, target_len), dtype=attention_mask.dtype, device=attention_mask.device)
    compact_pos = torch.zeros((n_rows, target_len), dtype=position_ids.dtype, device=position_ids.device)

    lengths = row_lengths.tolist()
    for row in range(n_rows):
        keep = attention_mask[row]
        n_valid = lengths[row]
        if pre_process:
            compact_ids[row, :n_valid] = input_ids[row, keep]
        compact_mask[row, :n_valid] = attention_mask[row, keep]
        compact_pos[row, :n_valid] = position_ids[row, keep]

    first = compact_ids if pre_process else input_ids
    return first, compact_mask, compact_pos
|
|
|
|
|
|
|
|
|
def recover_left_padding(
    result,
    attention_mask: torch.Tensor,
    original_attention_mask: torch.Tensor,
    origin_seqlen: int,
    post_process: bool = True,
):
    """
    Scatter per-row values back to their original token positions.

    For each row ``i``, the entries ``result[i, attention_mask[i]]`` are written, in
    order, to the positions ``original_attention_mask[i]`` of a zero-filled tensor. That
    tensor has ``result``'s shape except that dim 1 is ``origin_seqlen``.
    NOTE(review): presumably the inverse of ``remove_left_padding`` — confirm with callers.

    Returns:
        ``result`` itself when ``post_process`` is False; otherwise the new tensor, with
        ``result``'s dtype and device.
    """
    if not post_process:
        return result

    restored_shape = [*result.shape]
    restored_shape[1] = origin_seqlen
    restored = torch.zeros(restored_shape, dtype=result.dtype, device=result.device)

    n_rows = restored_shape[0]
    for row in range(n_rows):
        dst_mask = original_attention_mask[row]
        src_mask = attention_mask[row]
        restored[row, dst_mask] = result[row, src_mask]
    return restored
|
|
|
|