# Source: arcisvlm/tests/test_data.py (4.15 kB)
# Commit 7a564e3 (Hardik Sanghvi):
#   feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
"""Tests for multi-dataset pipeline."""
import unittest
import torch
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.formatters import format_for_task, classify_dataset_task
from data.multi_dataset import UnifiedVLMDataset, DatasetMixer, build_dummy_dataset
class TestFormatters(unittest.TestCase):
    """Checks on ``format_for_task`` and ``classify_dataset_task``.

    The checks use ``unittest`` assertion methods instead of bare ``assert``
    statements so they are still executed under ``python -O`` and report the
    compared values on failure.
    """

    def test_vqa_format(self):
        """VQA formatting yields the short-answer instruction, token and answer."""
        result = format_for_task("What is this?", "a cat", "vqa")
        self.assertIn("single word or phrase", result["instruction"])
        self.assertEqual(result["task_token"], "<vqa>")
        self.assertEqual(result["answer"], "a cat")

    def test_mcq_format(self):
        """Multiple-choice formatting asks for the option's letter."""
        result = format_for_task("What? A) cat B) dog", "A", "mcq")
        self.assertIn("option's letter", result["instruction"])

    def test_caption_format(self):
        """Caption formatting is tagged with the ``<caption>`` task token."""
        result = format_for_task("", "A beautiful sunset", "caption")
        self.assertEqual(result["task_token"], "<caption>")

    def test_detection_format(self):
        """Detection formatting mentions objects and uses the ``<detect>`` token."""
        result = format_for_task("", "person, car", "detection")
        # Case-insensitive match: only the presence of the word matters here.
        self.assertIn("objects", result["instruction"].lower())
        self.assertEqual(result["task_token"], "<detect>")

    def test_alert_format(self):
        """Alert formatting is tagged with the ``<alert>`` task token."""
        result = format_for_task("", "intruder detected", "alert")
        self.assertEqual(result["task_token"], "<alert>")

    def test_classify_datasets(self):
        """Each known dataset name maps to its expected task type."""
        expected_tasks = {
            "vqav2": "vqa",
            "textvqa": "ocr",
            "coco_detect": "detection",
            "ucf_crime": "alert",
            "llava_instruct": "conversation",
        }
        for dataset_name, task in expected_tasks.items():
            # subTest keeps going after a failure and names the failing dataset.
            with self.subTest(dataset=dataset_name):
                self.assertEqual(classify_dataset_task(dataset_name), task)
class TestUnifiedDataset(unittest.TestCase):
    """Checks on ``build_dummy_dataset`` and ``UnifiedVLMDataset``.

    The checks use ``unittest`` assertion methods instead of bare ``assert``
    statements so they are still executed under ``python -O`` and report the
    compared values on failure.
    """

    def test_dummy_samples(self):
        """The dummy builder returns the requested count with the raw sample keys."""
        samples = build_dummy_dataset("vqav2", num_samples=10, img_size=56)
        self.assertEqual(len(samples), 10)
        first = samples[0]
        for key in ("image", "question", "answer"):
            with self.subTest(key=key):
                self.assertIn(key, first)

    def test_unified_dataset_no_tokenizer(self):
        """Without a tokenizer, items expose the formatted text fields."""
        samples = build_dummy_dataset("vqav2", num_samples=5, img_size=56)
        ds = UnifiedVLMDataset(samples, "vqav2", img_size=56)
        self.assertEqual(len(ds), 5)
        item = ds[0]
        for key in ("image", "instruction", "answer", "task_token"):
            with self.subTest(key=key):
                self.assertIn(key, item)

    def test_unified_dataset_with_tokenizer(self):
        """With a tokenizer, items additionally carry token id fields."""
        # Imported here, as in the original, so only this test needs the module.
        from model.tokenizer import BPETokenizer

        tok = BPETokenizer(vocab_size=512)
        tok.train(["hello world test data question answer"] * 20)
        samples = build_dummy_dataset("vqav2", num_samples=5, img_size=56)
        ds = UnifiedVLMDataset(samples, "vqav2", tokenizer=tok, img_size=56)
        item = ds[0]
        self.assertIn("instruction_ids", item)
        self.assertIn("answer_ids", item)
        self.assertIsInstance(item["instruction_ids"], torch.Tensor)
class TestDatasetMixer(unittest.TestCase):
    """Checks on ``DatasetMixer`` built over two dummy datasets.

    The checks use ``unittest`` assertion methods instead of bare ``assert``
    statements so they are still executed under ``python -O`` and report the
    compared values on failure.
    """

    def test_mixer_basic(self):
        """A weighted mixer reports ``total_samples`` items and yields formatted ones."""
        ds_a = build_dummy_dataset("vqav2", 50, 56)
        ds_b = build_dummy_dataset("coco_detect", 30, 56)
        ua = UnifiedVLMDataset(ds_a, "vqav2", img_size=56)
        ub = UnifiedVLMDataset(ds_b, "coco_detect", img_size=56)
        mixer = DatasetMixer(
            datasets={"vqav2": ua, "coco_detect": ub},
            weights={"vqav2": 0.7, "coco_detect": 0.3},
            total_samples=100,
        )
        self.assertEqual(len(mixer), 100)
        self.assertIn("instruction", mixer[0])

    def test_mixer_all_datasets_sampled(self):
        """With no explicit weights, items from both datasets show up."""
        ds_a = build_dummy_dataset("vqav2", 50, 56)
        ds_b = build_dummy_dataset("coco_detect", 50, 56)
        ua = UnifiedVLMDataset(ds_a, "vqav2", img_size=56)
        ub = UnifiedVLMDataset(ds_b, "coco_detect", img_size=56)
        mixer = DatasetMixer(
            datasets={"vqav2": ua, "coco_detect": ub},
            total_samples=200,
        )
        # Sample many items and collect the distinct task tokens seen.
        sample_count = min(200, len(mixer))
        tokens = {mixer[i]["task_token"] for i in range(sample_count)}
        self.assertGreaterEqual(
            len(tokens), 2, f"Expected multiple task types, got {tokens}"
        )
# Run the suite with the unittest runner when this file is executed directly.
if "__main__" == __name__:
    unittest.main()