# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from tensordict import TensorDict

from verl.utils import tensordict_utils as tu
from verl.utils.attention_utils import index_first_axis, unpad_input


def left_right_2_no_padding(data: TensorDict) -> TensorDict:
    """Convert a TensorDict from left-right padding to no-padding format.

    The conversion is done in place: ``data`` is mutated and returned.

    Args:
        data: TensorDict with "input_ids", "attention_mask", "response_mask"
            and "position_ids". "position_ids" is indexed per sample as either
            (seq_len,) or (4, seq_len). An optional "routed_experts" entry is
            converted as well when it is present and not already nested.

    Returns:
        The same TensorDict, with

        - "input_ids" and "position_ids" replaced by jagged nested tensors
          that only hold the positions selected by "attention_mask";
        - "loss_mask" set to the object stored under "response_mask"
          (it is not converted by this function);
        - "routed_experts" replaced by a jagged nested tensor, if applicable;
        - non-tensor data "max_seq_len", "max_response_len" and "indices".

    Note:
        "attention_mask" and "response_mask" are left in ``data`` untouched.

    Raises:
        AssertionError: if one of the required keys is missing.
    """
    assert "input_ids" in data, "input_ids is required in left-right padding data"
    assert "attention_mask" in data, "attention_mask is required in left-right padding data"
    assert "response_mask" in data, "response_mask is required in left-right padding data"
    assert "position_ids" in data, "position_ids is required in left-right padding data"

    input_ids = data.pop("input_ids")
    attention_mask = data["attention_mask"]
    response_mask = data["response_mask"]
    position_ids = data["position_ids"]  # (bs, seq_len) or (bs, 4, seq_len)

    # Record the padded lengths before the padded tensors are replaced, so the
    # padded layout can be rebuilt later.
    max_seq_len, max_response_len = input_ids.shape[1], response_mask.shape[1]
    tu.assign_non_tensor_data(data, "max_seq_len", max_seq_len)
    tu.assign_non_tensor_data(data, "max_response_len", max_response_len)

    # NOTE(review): unpad_input is assumed to return the flattened unmasked
    # tokens, their flat indices and the cumulative sequence lengths, in that
    # order -- confirm against verl.utils.attention_utils.
    input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
    tu.assign_non_tensor_data(data, "indices", indices)

    input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens)

    # Keep only the unmasked positions of every sample. The mask is applied on
    # the last dimension so both (seq_len,) and (4, seq_len) layouts work.
    position_ids_list = [
        sample_pos_ids[sample_mask] if sample_pos_ids.dim() == 1 else sample_pos_ids[:, sample_mask]
        for sample_mask, sample_pos_ids in zip(attention_mask.bool(), position_ids, strict=True)
    ]
    position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged)

    data["input_ids"] = input_ids_nested
    data["position_ids"] = position_ids_nested
    data["loss_mask"] = data["response_mask"]

    routed_experts = data.get("routed_experts", None)
    if routed_experts is not None and not routed_experts.is_nested:
        # Store expert ids as uint8 when every id fits.
        # NOTE(review): only the upper bound is checked; ids are presumably
        # non-negative -- confirm with the producer of "routed_experts".
        if routed_experts.max() <= 255:
            routed_experts = routed_experts.to(torch.uint8)
        # Flatten (bs, seq_len, ...) to (bs * seq_len, ...) and gather the
        # same token positions that were kept for input_ids.
        routed_experts_rmpad = index_first_axis(routed_experts.unsqueeze(-1).flatten(0, 1), indices)
        routed_experts_nested = torch.nested.nested_tensor_from_jagged(
            routed_experts_rmpad.squeeze(-1), offsets=cu_seqlens
        )
        data["routed_experts"] = routed_experts_nested

    return data


def no_padding_2_padding(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
    """Slice the response part out of an unpadded model output and right-pad it.

    Args:
        tensor: a nested tensor or a 1D tensor in shape (total_nnz,), where
            total_nnz is the total number of tokens across all sequences in
            the batch.
        data: TensorDict with "prompts", "responses" and "attention_mask".
            When "prompts" is nested, lengths are read from the nested offsets
            of "prompts"/"responses" and the target width is the non-tensor
            "max_response_len" (falling back to the longest response).
            Otherwise lengths are derived from the padded "attention_mask"
            and the target width is ``responses.shape[1]``.

    Returns:
        Tensor of shape [bsz, max_response_len]; every row holds the outputs
        for one response, zero-padded on the right.

    Raises:
        AssertionError: if "prompts" is padded but "attention_mask" is nested,
            or if the summed sequence lengths do not match the number of
            values in ``tensor``.
    """
    values = tensor.values() if tensor.is_nested else tensor
    prompt_ids = data["prompts"]
    response_ids = data["responses"]
    attention_mask = data["attention_mask"]

    max_response_len = tu.get_non_tensor_data(data=data, key="max_response_len", default=-1)

    if prompt_ids.is_nested:
        prompt_lens = prompt_ids.offsets().diff()
        response_lens = response_ids.offsets().diff()
        if max_response_len < 0:
            max_response_len = response_lens.max().item()
    else:
        assert not attention_mask.is_nested
        prompt_width = prompt_ids.shape[1]
        prompt_lens = attention_mask[:, :prompt_width].sum(dim=1)
        response_lens = attention_mask[:, prompt_width:].sum(dim=1)
        max_response_len = response_ids.shape[1]

    sequence_offsets = (prompt_lens + response_lens).cumsum(dim=0)

    # Convert lengths and end offsets to Python ints once, instead of turning
    # a 0-dim tensor into an index on every loop iteration.
    response_len_list = response_lens.tolist()
    sequence_end_list = sequence_offsets.tolist()
    assert sequence_end_list[-1] == values.shape[0]

    response_list = []
    for resp_len, seq_end in zip(response_len_list, sequence_end_list, strict=True):
        pad_size = max_response_len - resp_len
        # left-shift model output by one token for log_probs/values
        response_slice = values[seq_end - resp_len - 1 : seq_end - 1]
        response_list.append(F.pad(response_slice, (0, pad_size)))

    return torch.stack(response_list, dim=0)