# RL/model/EasyR1/tests/test_dynamic_batch.py
# Hugging Face Hub page residue: "WangYe007's picture" / "Upload folder using huggingface_hub" / d65b589 (verified)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from verl.protocol import DataProto
from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch
def _create_random_mask(
input_ids: torch.Tensor,
max_ratio_of_valid_token: float,
max_ratio_of_left_padding: float,
min_ratio_of_valid_token: float = 0,
) -> torch.Tensor:
"""Create a random mask given input_ids. Support left padding and right padding.
Process:
- Sample valid token length
- Sample left_padding length
- Generate padding
Args:
input_ids:
shape (batch_size, seq_len)
Returns:
mask:
shape (batch_size, seq_len)
"""
assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0
assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0
assert min_ratio_of_valid_token <= max_ratio_of_valid_token
batch_size, sequence_length = input_ids.shape
max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
max_left_padding = int(sequence_length * max_ratio_of_left_padding)
assert max_num_valid_tokens + max_left_padding <= sequence_length
assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length
mask = torch.ones_like(input_ids, dtype=torch.int64)
# TODO: we can make this faster
for i in range(batch_size):
num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)
num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)
for index in range(num_left_padding):
mask[i, index] = 0
for index in range(num_left_padding + num_valid, sequence_length):
mask[i, index] = 0
return mask
def test_dynamic_batch():
    """Round-trip check for prepare_dynamic_batch / restore_dynamic_batch.

    Builds a random (20, 100) batch of token ids with a random padding mask,
    runs it through ``prepare_dynamic_batch`` with ``max_token_len=300``,
    concatenates the ``input_ids`` of the returned micro-batches, and asserts
    that ``restore_dynamic_batch`` reproduces the original ``input_ids``.
    """
    original_ids = torch.randint(low=0, high=10, size=(20, 100))
    random_mask = _create_random_mask(
        input_ids=original_ids,
        max_ratio_of_left_padding=0.1,
        max_ratio_of_valid_token=0.9,
        min_ratio_of_valid_token=0.5,
    )
    dataproto = DataProto.from_single_dict({"input_ids": original_ids, "attention_mask": random_mask})

    micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300)

    # Stack the micro-batches back together in the order they were returned.
    pieces = []
    for micro_batch in micro_batches:
        pieces.append(micro_batch.batch["input_ids"])
    concatenated = torch.cat(pieces, dim=0)

    restored = restore_dynamic_batch(concatenated, micro_bsz_idx_lst)
    torch.testing.assert_close(restored, dataproto.batch["input_ids"])