| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
|
|
| from verl.protocol import DataProto |
| from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch |
|
|
|
|
| def _create_random_mask( |
| input_ids: torch.Tensor, |
| max_ratio_of_valid_token: float, |
| max_ratio_of_left_padding: float, |
| min_ratio_of_valid_token: float = 0, |
| ) -> torch.Tensor: |
| """Create a random mask given input_ids. Support left padding and right padding. |
| |
| Process: |
| - Sample valid token length |
| - Sample left_padding length |
| - Generate padding |
| |
| Args: |
| input_ids: |
| shape (batch_size, seq_len) |
| |
| Returns: |
| mask: |
| shape (batch_size, seq_len) |
| """ |
| assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0 |
| assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0 |
| assert min_ratio_of_valid_token <= max_ratio_of_valid_token |
|
|
| batch_size, sequence_length = input_ids.shape |
| max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token) |
| min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token)) |
| max_left_padding = int(sequence_length * max_ratio_of_left_padding) |
| assert max_num_valid_tokens + max_left_padding <= sequence_length |
| assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length |
| mask = torch.ones_like(input_ids, dtype=torch.int64) |
| |
| for i in range(batch_size): |
| num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64) |
| num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64) |
|
|
| for index in range(num_left_padding): |
| mask[i, index] = 0 |
|
|
| for index in range(num_left_padding + num_valid, sequence_length): |
| mask[i, index] = 0 |
|
|
| return mask |
|
|
|
|
def test_dynamic_batch():
    """Round-trip check for dynamic micro-batching.

    Splits a randomly padded batch with ``prepare_dynamic_batch``, concatenates
    the micro-batches' ``input_ids``, and verifies that
    ``restore_dynamic_batch`` recovers the original ``input_ids`` tensor.
    """
    original_ids = torch.randint(low=0, high=10, size=(20, 100))
    mask = _create_random_mask(
        input_ids=original_ids,
        max_ratio_of_left_padding=0.1,
        max_ratio_of_valid_token=0.9,
        min_ratio_of_valid_token=0.5,
    )
    batch = DataProto.from_single_dict({"input_ids": original_ids, "attention_mask": mask})

    # NOTE(review): max_token_len presumably caps tokens per micro-batch, and the
    # returned index lists presumably record how samples were regrouped.
    micro_batches, index_lists = prepare_dynamic_batch(batch, max_token_len=300)

    concatenated = torch.cat([mb.batch["input_ids"] for mb in micro_batches], dim=0)
    restored = restore_dynamic_batch(concatenated, index_lists)

    torch.testing.assert_close(restored, batch.batch["input_ids"])
|
|