# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

from verl.protocol import DataProto
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch


def _create_random_mask(
    input_ids: torch.Tensor,
    max_ratio_of_valid_token: float,
    max_ratio_of_left_padding: float,
    min_ratio_of_valid_token: float = 0,
) -> torch.Tensor:
    """Create a random 0/1 mask for ``input_ids`` with left and right padding.

    For every row, a left-padding length and a valid-token length are sampled
    independently. Positions inside the valid span are 1; the left padding
    before it and everything after it (right padding) are 0.

    Args:
        input_ids: Tensor of shape (batch_size, seq_len). Only its shape and
            device are used; the values are never read.
        max_ratio_of_valid_token: Upper bound, in (0, 1], on the fraction of
            ``seq_len`` that may be valid in a row.
        max_ratio_of_left_padding: Upper bound, in [0, 1), on the fraction of
            ``seq_len`` that may be left padding in a row.
        min_ratio_of_valid_token: Lower bound on the fraction of ``seq_len``
            that is valid in a row. At least one token is always valid.

    Returns:
        An int64 mask of shape (batch_size, seq_len).

    Raises:
        AssertionError: If the ratios are out of range, or if the largest
            possible left padding plus valid span would not fit in ``seq_len``.
    """
    assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0
    assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0
    assert min_ratio_of_valid_token <= max_ratio_of_valid_token

    batch_size, sequence_length = input_ids.shape
    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
    max_left_padding = int(sequence_length * max_ratio_of_left_padding)

    # The worst case (max left padding + max valid span) must fit in a row.
    assert max_num_valid_tokens + max_left_padding <= sequence_length
    # Compare the token *count* (not the ratio) against the sequence length.
    assert max_num_valid_tokens > 0 and max_num_valid_tokens <= sequence_length

    mask = torch.ones_like(input_ids, dtype=torch.int64)
    for row in range(batch_size):
        # Draw left padding first, then valid length, once per row, so the
        # sequence of RNG calls is the same as a straightforward per-row loop.
        num_left_padding = int(np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64))
        num_valid = int(np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64))

        # Zero out both padding regions with slice assignment instead of
        # element-by-element writes; empty slices are harmless no-ops.
        mask[row, :num_left_padding] = 0
        mask[row, num_left_padding + num_valid :] = 0
    return mask


def test_dynamic_batch():
    """Round-trip check for ``prepare_dynamic_batch`` / ``restore_dynamic_batch``.

    Splits a randomly masked batch into micro-batches, concatenates the
    micro-batches' ``input_ids``, restores them using the returned index
    lists, and asserts the result equals the original ``input_ids``.
    """
    input_ids = torch.randint(low=0, high=10, size=(20, 100))
    attention_mask = _create_random_mask(
        input_ids=input_ids,
        max_ratio_of_left_padding=0.1,
        max_ratio_of_valid_token=0.9,
        min_ratio_of_valid_token=0.5,
    )
    data = {"input_ids": input_ids, "attention_mask": attention_mask}
    dataproto = DataProto.from_single_dict(data)

    # NOTE(review): max_token_len is presumably a per-micro-batch token
    # budget; confirm against prepare_dynamic_batch.
    micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300)

    # Stitch the micro-batches back together, then undo the reordering.
    input_ids = torch.cat([micro_batch.batch["input_ids"] for micro_batch in micro_batches], dim=0)
    input_ids = restore_dynamic_batch(input_ids, micro_bsz_idx_lst)
    torch.testing.assert_close(input_ids, dataproto.batch["input_ids"])