"""Tests for multi-dataset pipeline.""" import unittest import torch import sys import os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.formatters import format_for_task, classify_dataset_task from data.multi_dataset import UnifiedVLMDataset, DatasetMixer, build_dummy_dataset class TestFormatters(unittest.TestCase): def test_vqa_format(self): result = format_for_task("What is this?", "a cat", "vqa") assert "single word or phrase" in result["instruction"] assert result["task_token"] == "" assert result["answer"] == "a cat" def test_mcq_format(self): result = format_for_task("What? A) cat B) dog", "A", "mcq") assert "option's letter" in result["instruction"] def test_caption_format(self): result = format_for_task("", "A beautiful sunset", "caption") assert result["task_token"] == "" def test_detection_format(self): result = format_for_task("", "person, car", "detection") assert "objects" in result["instruction"].lower() assert result["task_token"] == "" def test_alert_format(self): result = format_for_task("", "intruder detected", "alert") assert result["task_token"] == "" def test_classify_datasets(self): assert classify_dataset_task("vqav2") == "vqa" assert classify_dataset_task("textvqa") == "ocr" assert classify_dataset_task("coco_detect") == "detection" assert classify_dataset_task("ucf_crime") == "alert" assert classify_dataset_task("llava_instruct") == "conversation" class TestUnifiedDataset(unittest.TestCase): def test_dummy_samples(self): samples = build_dummy_dataset("vqav2", num_samples=10, img_size=56) assert len(samples) == 10 assert "image" in samples[0] assert "question" in samples[0] assert "answer" in samples[0] def test_unified_dataset_no_tokenizer(self): samples = build_dummy_dataset("vqav2", num_samples=5, img_size=56) ds = UnifiedVLMDataset(samples, "vqav2", img_size=56) assert len(ds) == 5 item = ds[0] assert "image" in item assert "instruction" in item assert "answer" in item assert "task_token" in item def test_unified_dataset_with_tokenizer(self): from model.tokenizer import BPETokenizer tok = BPETokenizer(vocab_size=512) tok.train(["hello world test data question answer"] * 20) samples = build_dummy_dataset("vqav2", num_samples=5, img_size=56) ds = UnifiedVLMDataset(samples, "vqav2", tokenizer=tok, img_size=56) item = ds[0] assert "instruction_ids" in item assert "answer_ids" in item assert isinstance(item["instruction_ids"], torch.Tensor) class TestDatasetMixer(unittest.TestCase): def test_mixer_basic(self): ds_a = build_dummy_dataset("vqav2", 50, 56) ds_b = build_dummy_dataset("coco_detect", 30, 56) ua = UnifiedVLMDataset(ds_a, "vqav2", img_size=56) ub = UnifiedVLMDataset(ds_b, "coco_detect", img_size=56) mixer = DatasetMixer( datasets={"vqav2": ua, "coco_detect": ub}, weights={"vqav2": 0.7, "coco_detect": 0.3}, total_samples=100, ) assert len(mixer) == 100 item = mixer[0] assert "instruction" in item def test_mixer_all_datasets_sampled(self): ds_a = build_dummy_dataset("vqav2", 50, 56) ds_b = build_dummy_dataset("coco_detect", 50, 56) ua = UnifiedVLMDataset(ds_a, "vqav2", img_size=56) ub = UnifiedVLMDataset(ds_b, "coco_detect", img_size=56) mixer = DatasetMixer( datasets={"vqav2": ua, "coco_detect": ub}, total_samples=200, ) # Sample many items, check we get both task tokens tokens = set() for i in range(min(200, len(mixer))): tokens.add(mixer[i]["task_token"]) assert len(tokens) >= 2, f"Expected multiple task types, got {tokens}" if __name__ == "__main__": unittest.main()