OmniCoreX / tests /test_data_loader.py
Kosasih's picture
Rename test/test_data_loader.py to tests/test_data_loader.py
1830953 verified
import unittest
import tempfile
import json
from data_loader import OmniCoreXMultiModalDataset, create_omncorex_dataloader
def dummy_tokenizer(text):
return [ord(c) % 50 + 1 for c in text]
class DataLoaderTest(unittest.TestCase):
def setUp(self):
self.samples = [
{"text": "Hello", "sensor": [0.1, 0.2, 0.3], "image": ""},
{"text": "World", "sensor": [0.4, 0.5], "image": ""}
]
self.tmp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json')
json.dump(self.samples, self.tmp_file)
self.tmp_file.close()
def tearDown(self):
import os
os.unlink(self.tmp_file.name)
def test_dataset_length(self):
ds = OmniCoreXMultiModalDataset(self.tmp_file.name, ["text", "image", "sensor"], tokenizer=dummy_tokenizer)
self.assertEqual(len(ds), 2)
def test_dataloader_batch(self):
dl = create_omncorex_dataloader(self.tmp_file.name, ["text", "image", "sensor"], tokenizer=dummy_tokenizer, batch_size=2)
batch = next(iter(dl))
self.assertIn("text", batch)
self.assertIn("sensor", batch)
self.assertIn("image", batch)
if __name__ == "__main__":
unittest.main()