# Source: MoTIF/utils/core/tests/dataloader_test.py
# (uploaded via huggingface_hub by P4ddyki, commit 3cf4fff)
import os
from unittest import TestCase
import torch
from apps.plm.dataset_conf import dataset_config as DATASET_CONFIGS
from apps.plm.tokenizer import PLMTokenizer
from core.data.dataloader import DataloadArgs, get_dataloader
# Usage (TOKENIZER_PATH must point at a PLM tokenizer model):
#   TOKENIZER_PATH=facebook/Perception-LM-1B/tokenizer.model python -m unittest core/tests/dataloader_test.py
class DataloaderTest(TestCase):
    """Integration test for the PLM dataloader on the ``dummy_image`` datamix.

    Requires the ``TOKENIZER_PATH`` environment variable to point at a PLM
    tokenizer model; the tests are skipped when it is not set.
    """

    def setUp(self):
        """Build the PLM tokenizer shared by the tests in this case."""
        # NOTE(review): 8196 may be a typo for 8192 -- confirm whether the
        # value is intentional before changing it.
        self.seq_len = 8196
        self.patch_size = 14
        self.pooling_ratio = 2
        self.max_num_tiles = 9
        self.image_res = 448
        tokenizer_path = os.environ.get("TOKENIZER_PATH")
        if not tokenizer_path:
            # Skip instead of erroring out with a bare KeyError when the
            # tokenizer location has not been configured.
            self.skipTest("TOKENIZER_PATH environment variable is not set")
        self.mllm_tokenizer = PLMTokenizer(
            tokenizer_path,
            seq_len=self.seq_len,
            patch_size=self.patch_size,
            pooling_ratio=self.pooling_ratio,
        )

    def test_jsonl_image_text_dataloader(self):
        """Check image-token counts and ``image_pos_index`` over three batches.

        For every example, the number of image tokens must equal
        ``(image_res // patch_size // pooling_ratio) ** 2 * (max_num_tiles + 1)``.
        ``image_pos_index`` must be ``-1`` at every non-image position and
        ``0, 1, ..., n - 1`` at the ``n`` image positions.
        """
        dataloader_args = DataloadArgs(
            datamix="dummy_image:1",
            num_workers=1,
            vision_input_type="thumb+tile",
            image_res=self.image_res,
            max_num_tiles=self.max_num_tiles,
            batch_size=1,
        )
        dataloader = get_dataloader(
            dataloader_args,
            dp_rank=0,
            dp_world_size=1,
            dataset_configs=DATASET_CONFIGS,
            tokenizer=self.mllm_tokenizer,
        )
        batch_iterator = iter(dataloader)
        # Tokens per tile after patching and pooling, times (max_num_tiles + 1).
        # The +1 is presumably the thumbnail in "thumb+tile" mode. This also
        # assumes the dummy images always yield the maximum tile count -- TODO
        # confirm.
        tokens_per_side = self.image_res // self.patch_size // self.pooling_ratio
        expected_num_image_tokens = tokens_per_side**2 * (self.max_num_tiles + 1)
        print(f"expected_num_image_tokens: {expected_num_image_tokens}")
        # Use a throwaway name so the per-example index in the helper cannot
        # be confused with the batch counter.
        for _ in range(3):
            self._check_batch(next(batch_iterator), expected_num_image_tokens)

    def _check_batch(self, mllm_batch, expected_num_image_tokens):
        """Assert the image-token invariants for every example in one batch.

        Args:
            mllm_batch: Batch yielded by the dataloader; its ``x`` and
                ``image_pos_index`` attributes are read.
            expected_num_image_tokens: Exact number of image tokens expected
                in each example.
        """
        image_token_mask = mllm_batch.x == self.mllm_tokenizer.image_token_id
        num_image_tokens = image_token_mask.sum(dim=1)
        # Positions that are not image tokens must carry the -1 sentinel.
        self.assertTrue((mllm_batch.image_pos_index[~image_token_mask] == -1).all())
        for example_idx in range(mllm_batch.x.shape[0]):
            # Convert to a Python int so assertion failures print a plain value.
            count = int(num_image_tokens[example_idx].item())
            print(f"num_image_tokens in example {example_idx}", count)
            self.assertEqual(count, expected_num_image_tokens)
            cur_image_pos_index = mllm_batch.image_pos_index[example_idx]
            # Match dtype/device so torch.equal compares like with like.
            expected_positions = torch.arange(
                count,
                dtype=cur_image_pos_index.dtype,
                device=cur_image_pos_index.device,
            )
            self.assertTrue(
                torch.equal(
                    cur_image_pos_index[image_token_mask[example_idx]],
                    expected_positions,
                )
            )