|
|
import os |
|
|
from unittest import TestCase |
|
|
|
|
|
import torch |
|
|
|
|
|
from apps.plm.dataset_conf import dataset_config as DATASET_CONFIGS |
|
|
from apps.plm.tokenizer import PLMTokenizer |
|
|
from core.data.dataloader import DataloadArgs, get_dataloader |
|
|
|
|
|
|
|
|
|
|
|
class DataloaderTest(TestCase):
    """Tests for the image-text dataloader returned by ``get_dataloader``.

    Verifies that image placeholder tokens in tokenized batches agree with the
    batch's ``image_pos_index`` tensor.
    """

    def setUp(self):
        """Build the ``PLMTokenizer`` shared by the tests.

        The tokenizer location is read from the ``TOKENIZER_PATH`` environment
        variable; ``os.environ[...]`` raises ``KeyError`` if it is unset.
        """
        # NOTE(review): 8196 may be a typo for 8192 — confirm the intended value.
        self.seq_len = 8196
        self.patch_size = 14
        self.pooling_ratio = 2
        self.max_num_tiles = 9
        self.image_res = 448
        self.mllm_tokenizer = PLMTokenizer(
            os.environ["TOKENIZER_PATH"],
            seq_len=self.seq_len,
            patch_size=self.patch_size,
            pooling_ratio=self.pooling_ratio,
        )

    def test_jsonl_image_text_dataloader(self):
        """Check image-token counts and ``image_pos_index`` on the first batches.

        For each of the first ``num_batches`` batches drawn from the
        ``dummy_image`` datamix:

        * every position that is not an image token has ``image_pos_index == -1``;
        * each sample contains exactly ``expected_num_image_tokens`` image tokens;
        * at the ``n`` image-token positions, ``image_pos_index`` is ``0..n-1``
          in order.
        """
        num_batches = 3

        dataloader_args = DataloadArgs(
            datamix="dummy_image:1",
            num_workers=1,
            vision_input_type="thumb+tile",
            image_res=self.image_res,
            max_num_tiles=self.max_num_tiles,
            batch_size=1,
        )
        dataloader = get_dataloader(
            dataloader_args,
            dp_rank=0,
            dp_world_size=1,
            dataset_configs=DATASET_CONFIGS,
            tokenizer=self.mllm_tokenizer,
        )
        batch_iterator = iter(dataloader)

        # (image_res // patch_size // pooling_ratio) ** 2 tokens per image crop,
        # times (max_num_tiles + 1) crops — presumably one thumbnail plus
        # max_num_tiles tiles for "thumb+tile". Assumes every dummy image fills
        # all tile slots — TODO confirm against the dataset.
        tokens_per_side = self.image_res // self.patch_size // self.pooling_ratio
        expected_num_image_tokens = tokens_per_side**2 * (self.max_num_tiles + 1)
        print(f"expected_num_image_tokens: {expected_num_image_tokens}")

        for batch_idx in range(num_batches):
            mllm_batch = next(batch_iterator)

            image_token_mask = mllm_batch.x == self.mllm_tokenizer.image_token_id
            num_image_tokens = image_token_mask.sum(dim=1)

            # Non-image positions must carry the -1 sentinel in image_pos_index.
            self.assertTrue(
                (mllm_batch.image_pos_index[~image_token_mask] == -1).all()
            )

            for sample_idx in range(mllm_batch.x.shape[0]):
                # Plain int gives readable assertion messages (not a 0-d tensor).
                n_image = int(num_image_tokens[sample_idx])
                print(f"num_image_tokens in example {sample_idx}", n_image)

                with self.subTest(batch=batch_idx, sample=sample_idx):
                    self.assertEqual(n_image, expected_num_image_tokens)

                    cur_image_pos_index = mllm_batch.image_pos_index[sample_idx]
                    # Match dtype/device so torch.equal compares like with like.
                    expected_positions = torch.arange(
                        n_image,
                        dtype=cur_image_pos_index.dtype,
                        device=cur_image_pos_index.device,
                    )
                    self.assertTrue(
                        torch.equal(
                            cur_image_pos_index[image_token_mask[sample_idx]],
                            expected_positions,
                        )
                    )
|
|
|