import torch
import torch.nn.functional as F
from tensordict import TensorDict

from verl.utils import tensordict_utils as tu
from verl.utils.attention_utils import index_first_axis, unpad_input


def left_right_2_no_padding(data: TensorDict) -> TensorDict:
| """ |
| Convert TensorDict from left-right padding to no-padding format. |
| |
| Args: |
| data: TensorDict with "input_ids", "attention_mask", "response_mask", "position_ids" |
| |
| Returns: |
| data: TensorDict with |
| - Tensor includes NestedTensors like "input_ids", "loss_mask", "position_ids" |
| - NonTensorData includes "max_seq_len", "max_response_len", "indices" |
| |
| Note: |
| 1. the return input_ids/position_ids/loss_mask are nested tensor. |
| 2. we will remove "attention_mask", "response" in the return data, but "response_mask" is kept. |
| """ |
| assert "input_ids" in data, "input_ids is required in left-right padding data" |
| assert "attention_mask" in data, "attention_mask is required in left-right padding data" |
| assert "response_mask" in data, "response_mask is required in left-right padding data" |
| assert "position_ids" in data, "position_ids is required in left-right padding data" |
|
|
| input_ids = data.pop("input_ids") |
| attention_mask = data["attention_mask"] |
| response_mask = data["response_mask"] |
| position_ids = data["position_ids"] |
|
|
| max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1] |
| tu.assign_non_tensor_data(data, "max_seq_len", max_seq_len) |
| tu.assign_non_tensor_data(data, "max_response_len", max_response_len) |
|
|
| input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) |
| tu.assign_non_tensor_data(data, "indices", indices) |
|
|
| input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens) |
|
|
| position_ids_list = [] |
| for i in range(attention_mask.shape[0]): |
| curr_mask = attention_mask[i].bool() |
| curr_pos_ids = position_ids[i] |
| if curr_pos_ids.dim() == 1: |
| valid_ids = curr_pos_ids[curr_mask] |
| else: |
| valid_ids = curr_pos_ids[:, curr_mask] |
| position_ids_list.append(valid_ids) |
| position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged) |
|
|
| data["input_ids"] = input_ids_nested |
| data["position_ids"] = position_ids_nested |
| data["loss_mask"] = data["response_mask"] |
|
|
| routed_experts = data.get("routed_experts", None) |
| if routed_experts is not None and not routed_experts.is_nested: |
| if routed_experts.max() <= 255: |
| routed_experts = routed_experts.to(torch.uint8) |
| routed_experts_rmpad = index_first_axis(routed_experts.unsqueeze(-1).flatten(0, 1), indices) |
| routed_experts_nested = torch.nested.nested_tensor_from_jagged( |
| routed_experts_rmpad.squeeze(-1), offsets=cu_seqlens |
| ) |
| data["routed_experts"] = routed_experts_nested |
|
|
| return data |
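

# Illustrative usage sketch (not part of verl): the tensor values below are
# invented, and the position-id convention (cumulative attended-token count,
# clamped at zero over the left padding) is one common choice, not a requirement.
def _demo_left_right_2_no_padding() -> None:
    attention_mask = torch.tensor([[0, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]])
    batch = TensorDict(
        {
            # Prompts are left-padded, responses right-padded (max lengths 3 and 3).
            "input_ids": torch.tensor([[0, 11, 12, 21, 22, 0], [13, 14, 15, 23, 24, 25]]),
            "attention_mask": attention_mask,
            "response_mask": torch.tensor([[1, 1, 0], [1, 1, 1]]),
            "position_ids": torch.clamp(attention_mask.cumsum(dim=-1) - 1, min=0),
        },
        batch_size=[2],
    )
    out = left_right_2_no_padding(batch)
    # The packed sequences keep only the 4 and 6 valid tokens, respectively.
    assert out["input_ids"].is_nested
    assert out["input_ids"].offsets().diff().tolist() == [4, 6]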


def no_padding_2_padding(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """Slice the response segment out of an unpadded model output and re-pad it.

    Args:
        tensor: a nested tensor, or a 1D tensor of shape (total_nnz,), where
            total_nnz is the total number of valid tokens across all sequences in the batch
        data: TensorDict with "prompts", "responses", "attention_mask"; an optional
            non-tensor "max_response_len" is used when the inputs are nested

    Returns:
        tensor: sliced response tensor of shape [bsz, max_response_len]
    """
    values = tensor.values() if tensor.is_nested else tensor
    prompt_ids = data["prompts"]
    response_ids = data["responses"]
    attention_mask = data["attention_mask"]

    max_response_len = tu.get_non_tensor_data(data=data, key="max_response_len", default=-1)

    # Recover per-sequence prompt/response lengths, either from the nested
    # layout (via offsets) or from the padded layout (via the attention mask).
    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        if max_response_len < 0:
            max_response_len = response_lens.max().item()
    else:
        assert not attention_mask.is_nested
        prompt_lens = attention_mask[:, : prompt_ids.shape[1]].sum(dim=1)
        response_lens = attention_mask[:, prompt_ids.shape[1] :].sum(dim=1)
        max_response_len = response_ids.shape[1]

    sequence_lens = prompt_lens + response_lens
    sequence_offsets = sequence_lens.cumsum(dim=0)
    assert sequence_offsets[-1].item() == values.shape[0]

    response_list = []
    for resp_len, seq_offset in zip(response_lens, sequence_offsets, strict=True):
        pad_size = max_response_len - resp_len
        # The slice is shifted left by one token: the model output at position i
        # corresponds to the prediction for token i + 1 (e.g. next-token log-probs).
        response_list.append(F.pad(values[seq_offset - resp_len - 1 : seq_offset - 1], (0, pad_size)))

    output = torch.stack(response_list, dim=0)
    return output
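

# Illustrative usage sketch (not part of verl): values are invented. The flat
# tensor stands in for a per-token model output such as next-token log-probs.
def _demo_no_padding_2_padding() -> None:
    data = TensorDict(
        {
            # Two sequences with (prompt_len, response_len) = (2, 2) and (3, 3),
            # padded to max_prompt_len=3 and max_response_len=3.
            "prompts": torch.tensor([[0, 11, 12], [13, 14, 15]]),
            "responses": torch.tensor([[21, 22, 0], [23, 24, 25]]),
            "attention_mask": torch.tensor([[0, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]),
        },
        batch_size=[2],
    )
    # One output value per valid token: 4 for the first sequence, 6 for the second.
    flat_output = torch.arange(10, dtype=torch.float32)
    padded = no_padding_2_padding(flat_output, data)
    assert padded.shape == (2, 3)
    # First row: outputs at the two response positions (shifted by one), then padding.
    assert padded[0].tolist() == [1.0, 2.0, 0.0]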
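

if __name__ == "__main__":
    # Run the illustrative sketches above (assumption-laden demos, not tests
    # shipped with verl).
    _demo_left_right_2_no_padding()
    _demo_no_padding_2_padding()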