| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import math |
| import os |
| from typing import Optional |
|
|
| import torch |
| from megatron.core import parallel_state as mpu |
| from megatron.core.packed_seq_params import PackedSeqParams |
|
|
| from verl.utils.model import CausalLMOutputForPPO |
|
|
# Module-level logger; note that it is keyed by this file's path (``__file__``).
logger = logging.getLogger(__file__)
# Verbosity comes from the VERL_LOGGING_LEVEL environment variable and falls back to "WARN".
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
| def _compute_fp8_thd_align_size(align_size: int) -> tuple[int, int]: |
| """Compute FP8 alignment sizes for thd-format sequences. |
| |
| For FP8 block quantization, each sequence must be padded to a multiple of |
| lcm(16, align_size), and the total padded length must be divisible by |
| (align_size * 128) for TransformerEngine compatibility. |
| |
| Returns (per_seq_align_size, total_align_size). |
| """ |
| return math.lcm(16, align_size), align_size * 128 |
|
|
|
|
def preprocess_packed_seqs(
    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding: bool = False
) -> tuple[torch.Tensor, PackedSeqParams]:
    """Pack a padded batch into one thd-format token stream.

    The valid tokens of every row (selected with ``attention_mask``) are laid
    out back to back; each sequence is padded up to an alignment derived from
    the tensor/context parallel sizes (and from the FP8 rules when
    ``use_fp8_padding`` is set).

    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1
    gets second and second last chunks, and so on), this is for load balancing with causal masking.
    See https://github.com/NVIDIA/TransformerEngine/issues/1368

    Args:
        input_ids: Batch of token ids, indexed as ``input_ids[i, attention_mask[i]]``.
        attention_mask: Per-row mask used directly as an index into ``input_ids``
            (presumably boolean -- confirm with callers).
        pre_process: When False no packed tensor is built and ``input_ids`` is
            returned untouched alongside the packing parameters.
        use_fp8_padding: Apply the FP8 per-sequence / total alignment from
            ``_compute_fp8_thd_align_size``; the extra total padding is added to
            the last sequence.

    Returns:
        ``(packed_input_ids, packed_seq_params)`` where ``packed_input_ids`` has
        a leading dimension of 1 (or is the original ``input_ids`` when
        ``pre_process`` is False) and every ``cu_seqlens_*`` field of
        ``packed_seq_params`` holds the padded cumulative lengths.
    """
    batch_size = input_ids.shape[0]

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    if use_fp8_padding:
        per_seq_align, total_align = _compute_fp8_thd_align_size(align_size)
        align_size = per_seq_align

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    if use_fp8_padding:
        # Fold the padding needed for the total-length constraint into the last sequence.
        pad_size_last = (total_align - cu_seqlens_padded[-1] % total_align) % total_align
        cu_seqlens_padded[-1] += pad_size_last
        seqlens_in_batch_padded[-1] += pad_size_last

    # Convert once to Python lists so the loop below indexes with plain ints.
    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()
    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()
    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()

    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)

    shape = list(input_ids.shape[1:])
    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        for i in range(batch_size):
            if cp_size <= 1:
                seqlen = seqlens_in_batch_cpu[i]
                start_idx = cu_seqlens_padded_cpu[i]
                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
                continue

            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
            seqlen = seqlen_padded_i // cp_size
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded_cpu[i] // cp_size

            # Split into 2 chunks: this rank's front chunk and its mirrored back chunk.
            d = input_ids[i, attention_mask[i]]

            # Clamp the front chunk to the real (unpadded) length: a sequence
            # shorter than half_seqlen * (cp_rank + 1) would otherwise yield a
            # slice shorter than the destination and fail the assignment. The
            # uncovered tail stays zero, i.e. it is padding.
            first_start = half_seqlen * cp_rank
            first_end = min(half_seqlen * (cp_rank + 1), d.shape[0])
            first_len = first_end - first_start
            if first_len > 0:
                input_ids_rmpad[start_idx : start_idx + first_len] = d[first_start:first_end]

            # The back chunk is clamped the same way; it may be partly or fully padding.
            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
            remain_end = seqlen_padded_i - half_seqlen * cp_rank
            remain_end = min(remain_end, d.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
                    remain_start:remain_end
                ]

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params
    else:
        return input_ids, packed_seq_params
|
|
|
|
def postprocess_packed_seqs(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    post_process: bool = True,
) -> torch.Tensor:
    """Scatter a packed (thd) output back into a ``[batch_size, seq_len, ...]`` tensor.

    Args:
        output: Packed model output; row ``output[0]`` holds the token stream.
        packed_seq_params: Packing parameters; only ``cu_seqlens_q_padded`` is read.
        attention_mask: Mask selecting, per row, the destination positions.
        batch_size: Number of sequences to restore.
        seq_len: Second dimension of the restored tensor.
        post_process: When False, ``output`` is returned unchanged.

    Returns:
        A zero-initialised tensor in which the masked positions of each row are
        filled with that sequence's outputs. With context parallelism the
        shards of all CP ranks are gathered first and re-interleaved.
    """
    if not post_process:
        return output

    # Pull offsets and lengths to Python lists once; the loops use plain ints.
    padded_offsets: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
    valid_lengths: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()

    trailing_dims = list(output.shape[2:])
    restored = torch.zeros([batch_size, seq_len] + trailing_dims, dtype=output.dtype, device=output.device)

    cp_size = mpu.get_context_parallel_world_size()

    if cp_size <= 1:
        # No context parallelism: every sequence sits contiguously at its padded offset.
        for row in range(batch_size):
            begin = padded_offsets[row]
            restored[row, attention_mask[row]] = output[0][begin : begin + valid_lengths[row]]
        return restored

    # Gather every rank's shard; the local slot is swapped back to the original
    # (non-detached) tensor after the gather.
    gathered = [torch.empty_like(output, dtype=output.dtype) for _ in range(cp_size)]
    torch.distributed.all_gather(gathered, output.detach(), group=mpu.get_context_parallel_group())
    gathered[mpu.get_context_parallel_rank()] = output

    for row in range(batch_size):
        chunk_len = (padded_offsets[row + 1] - padded_offsets[row]) // cp_size
        half_len = chunk_len // 2
        full_len = chunk_len * cp_size
        chunk_begin = padded_offsets[row] // cp_size
        merged = torch.empty(full_len, *output.shape[2:], device=output.device, dtype=output.dtype)
        for rank in range(cp_size):
            shard = gathered[rank][0]
            front = shard[chunk_begin : chunk_begin + half_len]
            back = shard[chunk_begin + half_len : chunk_begin + chunk_len]
            # Rank r owns chunk r from the front and chunk r counted from the back.
            merged[rank * half_len : (rank + 1) * half_len] = front
            merged[full_len - (rank + 1) * half_len : full_len - rank * half_len] = back
        restored[row, attention_mask[row]] = merged[: valid_lengths[row]]

    return restored
|
|
|
|
def preprocess_bshd(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
    sequence_parallel: bool = False,
    pre_process: bool = True,
):
    """Compact the valid tokens of every row to the start of the row.

    The positions selected by ``attention_mask`` are copied, in order, to the
    front of freshly allocated tensors whose length is the longest valid
    sequence in the batch (rounded up to a multiple of the tensor-parallel
    world size when ``sequence_parallel`` is set).

    Args:
        input_ids: Token ids; the first two dimensions are batch and sequence.
        attention_mask: 2-D mask used as an index into each row.
        position_ids: 2-D position ids aligned with ``input_ids``.
        sequence_parallel: Round the new length up to the TP world size.
        pre_process: When False the original ``input_ids`` is returned as the
            first element; mask and position ids are still rebuilt.

    Returns:
        ``(new_input_ids, new_attention_mask, new_position_ids)``.
    """
    assert attention_mask.ndim == 2
    assert position_ids.ndim == 2
    cp_size = mpu.get_context_parallel_world_size()
    assert cp_size == 1, "Context parallel size without seq_pack is not supported"

    batch_size = input_ids.shape[0]
    valid_counts = attention_mask.sum(dim=1)
    target_len = valid_counts.max().item()
    if sequence_parallel:
        sp_world_size = mpu.get_tensor_model_parallel_world_size()
        # Round up to the next multiple of the sequence-parallel world size.
        target_len += (-target_len) % sp_world_size

    ids_shape = list(input_ids.shape)
    ids_shape[1] = target_len

    if pre_process:
        compact_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=ids_shape)
    compact_mask = torch.zeros(
        dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, target_len)
    )
    compact_pos = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, target_len))

    for row in range(batch_size):
        keep = attention_mask[row]
        n_valid = valid_counts[row]
        if pre_process:
            compact_ids[row, :n_valid] = input_ids[row, keep]
        compact_mask[row, :n_valid] = attention_mask[row, keep]
        compact_pos[row, :n_valid] = position_ids[row, keep]

    first = compact_ids if pre_process else input_ids
    return first, compact_mask, compact_pos
|
|
|
|
def postprocess_bshd(
    result,
    attention_mask: torch.Tensor,
    original_attention_mask: torch.Tensor,
    origin_seqlen: int,
    post_process: bool = True,
):
    """Scatter compacted results back to their original sequence positions.

    Args:
        result: Tensor whose first two dimensions are batch and sequence.
        attention_mask: Mask selecting, per row, the entries of ``result`` to read.
        original_attention_mask: Mask selecting, per row, the destination
            positions in the restored tensor.
        origin_seqlen: Sequence length of the restored tensor.
        post_process: When False, ``result`` is returned unchanged.

    Returns:
        A zero-initialised tensor of sequence length ``origin_seqlen`` with the
        selected entries of ``result`` written at the original positions.
    """
    if not post_process:
        return result

    num_rows = result.shape[0]
    restored_shape = [num_rows, origin_seqlen, *result.shape[2:]]
    restored = torch.zeros(restored_shape, dtype=result.dtype, device=result.device)
    for row in range(num_rows):
        source_mask = attention_mask[row]
        target_mask = original_attention_mask[row]
        restored[row, target_mask] = result[row, source_mask]
    return restored
|
|
|
|
def postprocess_packed_seqs_for_dict_output(
    labels_mask: torch.Tensor,
    output: CausalLMOutputForPPO,
    packed_seq_params: PackedSeqParams,
    attention_mask: torch.Tensor,
    batch_size: int,
    seq_len: int,
    post_process: bool = True,
) -> dict[str, torch.Tensor]:
    """Post-process the tensors of a fused-kernel output object.

    For fused kernels, the output is an object carrying tensors such as
    ``log_probs`` and ``entropy``. Both are reshaped to ``(1, -1)``,
    ``log_probs`` is zeroed wherever ``labels_mask`` is False, and each is then
    passed through ``postprocess_packed_seqs``.

    Note: ``output.entropy`` and ``output.log_probs`` are overwritten in place
    with the reshaped (and, for ``log_probs``, masked) tensors.

    Args:
        labels_mask: Mask that is inverted with ``~`` and applied to the
            flattened ``log_probs``.
        output: Object exposing ``entropy`` and ``log_probs`` tensors.
        packed_seq_params: Forwarded to ``postprocess_packed_seqs``.
        attention_mask: Forwarded to ``postprocess_packed_seqs``.
        batch_size: Forwarded to ``postprocess_packed_seqs``.
        seq_len: Forwarded to ``postprocess_packed_seqs``.
        post_process: Forwarded to ``postprocess_packed_seqs``. Defaults to True.

    Returns:
        A dict with the keys ``"entropy"`` and ``"log_probs"``.
    """
    output.entropy = output.entropy.view(1, -1)
    output.log_probs = output.log_probs.view(1, -1).masked_fill(~labels_mask, 0.0)

    shared_args = (packed_seq_params, attention_mask, batch_size, seq_len)
    return {
        field: postprocess_packed_seqs(getattr(output, field), *shared_args, post_process=post_process)
        for field in ("entropy", "log_probs")
    }
|
|
|
|
| |
| |
def preprocess_thd_no_padding(
    input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False, use_fp8_padding: bool = False
) -> tuple[torch.Tensor, PackedSeqParams, Optional[torch.Tensor]]:
    """Pack a nested (unpadded) batch into one thd-format token stream.

    Sequence lengths are taken from ``input_ids.offsets()``; each sequence is
    padded up to an alignment derived from the tensor/context parallel sizes
    (and from the FP8 rules when ``use_fp8_padding`` is set). Position ids are
    generated alongside the packed token ids.

    CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1
    gets second and second last chunks, and so on), this is for load balancing with causal masking.
    See https://github.com/NVIDIA/TransformerEngine/issues/1368

    Args:
        input_ids: Batch exposing ``offsets()`` and per-row indexing
            (presumably a jagged nested tensor -- confirm with callers).
        pre_process: When False nothing is packed; ``input_ids`` is returned
            untouched and the position ids are ``None``.
        need_roll: Shift the packed ids and position ids left by one and patch
            the last slot of every local CP chunk with the value that follows
            it in the original sequence.
        use_fp8_padding: Apply the FP8 per-sequence / total alignment from
            ``_compute_fp8_thd_align_size``; the extra total padding is added
            to the last sequence.

    Returns:
        ``(packed_input_ids, packed_seq_params, packed_position_ids)``; the two
        tensors carry a leading dimension of 1.
    """
    batch_size = input_ids.shape[0]

    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    seqlens_in_batch = input_ids.offsets().diff()

    if use_fp8_padding:
        per_seq_align, total_align = _compute_fp8_thd_align_size(align_size)
        align_size = per_seq_align

    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    if use_fp8_padding:
        # Fold the padding needed for the total-length constraint into the last sequence.
        pad_size_last = (total_align - cu_seqlens_padded[-1] % total_align) % total_align
        cu_seqlens_padded[-1] += pad_size_last
        seqlens_in_batch_padded[-1] += pad_size_last

    # Convert once to Python lists so the loop below indexes with plain ints.
    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()
    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()
    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()

    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)

    shape = list(input_ids.shape[1:])
    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        position_ids_rmpad = torch.zeros(shape, dtype=torch.long, device=input_ids.device)
        # Values to re-insert after the roll, keyed by index in the packed buffer.
        saved_roll_dict = {}
        saved_position_roll_dict = {}
        for i in range(batch_size):
            if cp_size <= 1:
                seqlen = seqlens_in_batch_cpu[i]
                start_idx = cu_seqlens_padded_cpu[i]
                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i]
                position_ids_rmpad[start_idx : start_idx + seqlen] = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
                continue

            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
            seqlen = seqlen_padded_i // cp_size
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded_cpu[i] // cp_size

            # Split into 2 chunks: this rank's front chunk and its mirrored back chunk.
            d = input_ids[i]

            # A sequence shorter than the alignment is zero-padded up to it so
            # that the chunk slices below have the expected length.
            if d.numel() < align_size:
                original_size = d.numel()
                pad = torch.zeros(align_size - d.numel(), dtype=d.dtype, device=d.device)
                d = torch.cat([d, pad], dim=0)
                # ``warning_once`` is not part of the standard ``logging.Logger``
                # API; it only exists when a third-party package has patched it
                # in. Fall back to a plain warning rather than raising
                # AttributeError.
                warn = getattr(logger, "warning_once", logger.warning)
                warn(
                    f"Padding tensor for context parallel alignment, original_size={original_size}, "
                    f"align_size={align_size}"
                )

            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
            ]
            position_ids_rmpad[start_idx : start_idx + half_seqlen] = torch.arange(
                half_seqlen * cp_rank, half_seqlen * (cp_rank + 1), dtype=torch.long, device=input_ids.device
            )

            # The back chunk is clamped to the real length; it may be partly or fully padding.
            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
            remain_end = seqlen_padded_i - half_seqlen * cp_rank
            remain_end = min(remain_end, d.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
                    remain_start:remain_end
                ]
                position_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = torch.arange(
                    seqlen_padded_i - remain_len, seqlen_padded_i, dtype=torch.long, device=input_ids.device
                )

            if need_roll:
                # After the roll the last slot of each local chunk would hold the
                # first element of whatever follows it in the packed buffer.
                # Remember the token that follows the chunk in the original
                # sequence instead, together with the pre-roll position id.
                front_last = start_idx + half_seqlen - 1
                saved_roll_dict[front_last] = d[(cp_rank + 1) * half_seqlen]
                saved_position_roll_dict[front_last] = position_ids_rmpad[front_last]
                if remain_len > 0:
                    back_last = start_idx + half_seqlen + remain_len - 1
                    if remain_end == d.shape[0]:
                        # The back chunk reaches the end of the sequence: wrap to its first token.
                        saved_roll_dict[back_last] = d[0]
                        saved_position_roll_dict[back_last] = 0
                    else:
                        saved_roll_dict[back_last] = d[remain_end]
                        saved_position_roll_dict[back_last] = position_ids_rmpad[back_last]

        if need_roll:
            # torch.roll returns new tensors, so the position values saved above
            # (views into the pre-roll buffer) keep their pre-roll contents.
            input_ids_rmpad = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
            position_ids_rmpad = torch.roll(position_ids_rmpad, shifts=-1, dims=0)
            for k, v in saved_roll_dict.items():
                input_ids_rmpad[k] = v
            for k, v in saved_position_roll_dict.items():
                position_ids_rmpad[k] = v

    packed_seq_params = PackedSeqParams(
        qkv_format="thd",
        cu_seqlens_q=cu_seqlens_padded,
        max_seqlen_q=max_seqlen_in_batch,
        cu_seqlens_kv=cu_seqlens_padded,
        max_seqlen_kv=max_seqlen_in_batch,
        cu_seqlens_q_padded=cu_seqlens_padded,
        cu_seqlens_kv_padded=cu_seqlens_padded,
    )
    if pre_process:
        return input_ids_rmpad.unsqueeze(0), packed_seq_params, position_ids_rmpad.unsqueeze(0)
    else:
        return input_ids, packed_seq_params, None
|
|
|
|
def postprocess_thd_no_padding(
    output: torch.Tensor,
    packed_seq_params: PackedSeqParams,
    input_ids: torch.Tensor,
    batch_size: int,
    post_process: bool = True,
) -> torch.Tensor:
    """Split a packed (thd) output back into one tensor per sequence.

    Args:
        output: Packed model output; row ``output[0]`` holds the token stream.
        packed_seq_params: Packing parameters; only ``cu_seqlens_q_padded`` is read.
        input_ids: The batch that was packed; its ``offsets()`` provide the
            unpadded sequence lengths.
        batch_size: Number of sequences to restore.
        post_process: When False, ``output`` is returned unchanged.

    Returns:
        A jagged nested tensor holding, per sequence, its unpadded outputs.
        With context parallelism the shards of all CP ranks are gathered first
        and re-interleaved.
    """
    if not post_process:
        return output

    # Pull offsets and lengths to Python lists once; the loops use plain ints.
    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
    cu_seqlens = input_ids.offsets()
    seq_lens_cpu: list[int] = cu_seqlens.diff().tolist()

    output_new = []

    cp_size = mpu.get_context_parallel_world_size()

    if cp_size > 1:
        # Gather every rank's shard; the local slot is swapped back to the
        # original (non-detached) tensor after the gather.
        output_list = [torch.empty_like(output) for _ in range(cp_size)]
        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
        output_list[mpu.get_context_parallel_rank()] = output
    else:
        output_list = [output]

    for i in range(batch_size):
        if cp_size <= 1:
            s = seq_lens_cpu[i]
            start_idx = cu_padded_cpu[i]
            output_new.append(output[0][start_idx : start_idx + s])
            continue
        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
        half_seqlen = s_len_padded_chunk // 2
        s_len = seq_lens_cpu[i]
        s_len_padded = s_len_padded_chunk * cp_size
        packed_start_idx = cu_padded_cpu[i] // cp_size
        # Allocate with the output's dtype: without it the buffer would use the
        # default dtype and the CP path would return a different dtype than the
        # non-CP path.
        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device, dtype=output.dtype)
        for j in range(cp_size):
            o = output_list[j][0]
            o0, o1 = (
                o[packed_start_idx : packed_start_idx + half_seqlen],
                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
            )
            # Rank j owns chunk j from the front and chunk j counted from the back.
            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
        output_new.append(tmp[:s_len])

    output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged)

    return output_new_tensor
|
|
|
|
def preprocess_bshd_no_padding(
    input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False, use_fp8_padding: bool = False
):
    """Turn a nested (unpadded) batch into right-padded bshd tensors.

    Sequence lengths are taken from ``input_ids.offsets()``. The padded length
    is the longest sequence, rounded up to the tensor-parallel size when that
    is greater than 1, and additionally to an FP8-friendly multiple when
    ``use_fp8_padding`` is set.

    Args:
        input_ids: Batch exposing ``offsets()`` and per-row indexing
            (presumably a jagged nested tensor -- confirm with callers).
        pre_process: Not read by this function.
        need_roll: Shift the padded token ids left by one along the sequence axis.
        use_fp8_padding: Round the padded length up for FP8.

    Returns:
        ``(input_ids, attention_mask, position_ids)``, each of shape
        ``(batch_size, max_seqlen)``.
    """
    cp_size = mpu.get_context_parallel_world_size()
    assert cp_size == 1, "Context parallel size without bshd is not supported yet"

    batch_size = input_ids.shape[0]
    seqlens_in_batch = input_ids.offsets().diff()
    padded_len = seqlens_in_batch.max().item()

    tp_size = mpu.get_tensor_model_parallel_world_size()
    if tp_size > 1:
        # Round up to the next multiple of the tensor-parallel size.
        padded_len += (-padded_len) % tp_size

    if use_fp8_padding:
        # Smallest per-sequence multiple that makes batch_size * padded_len a
        # multiple of 128 * tp_size, kept a multiple of tp_size as well.
        total_align = 128 * tp_size
        seq_align = total_align // math.gcd(batch_size, total_align)
        seq_align = math.lcm(seq_align, tp_size)
        padded_len = -(-padded_len // seq_align) * seq_align

    device = input_ids.device
    attention_mask = torch.zeros(batch_size, padded_len, dtype=torch.bool, device=device)
    input_ids_bshd = torch.zeros(batch_size, padded_len, dtype=input_ids.dtype, device=device)
    for row in range(batch_size):
        row_len = seqlens_in_batch[row]
        attention_mask[row, :row_len] = True
        input_ids_bshd[row, :row_len] = input_ids[row]

    position_ids = torch.arange(padded_len, dtype=torch.long, device=device).unsqueeze(0).expand_as(input_ids_bshd)

    if need_roll:
        input_ids_bshd = torch.roll(input_ids_bshd, shifts=-1, dims=1)

    return input_ids_bshd, attention_mask, position_ids
|
|
|
|
| def postprocess_bshd_no_padding( |
| output: torch.Tensor, |
| attention_mask: torch.Tensor, |
| post_process: bool = True, |
| ) -> torch.Tensor: |
| """ |
| Postprocess bshd sequences |
| """ |
| if not post_process: |
| return output |
|
|
| batch_size = output.shape[0] |
| output_new = [] |
|
|
| for i in range(batch_size): |
| mask = attention_mask[i].bool() |
| output_new.append(output[i][mask]) |
|
|
| output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged) |
|
|
| return output_new_tensor |
|
|