File size: 1,228 Bytes
7a6e777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import json
import os
import tempfile
import unittest

from data_loader import OmniCoreXMultiModalDataset, create_omncorex_dataloader
def dummy_tokenizer(text: str) -> list:
    """Map each character of *text* to a small positive integer id.

    The id is ``ord(c) % 50 + 1``, so every id lies in ``1..50`` and the
    result has one entry per character. An empty string yields ``[]``.
    """
    return [ord(c) % 50 + 1 for c in text]
class DataLoaderTest(unittest.TestCase):
def setUp(self):
self.samples = [
{"text": "Hello", "sensor": [0.1, 0.2, 0.3], "image": ""},
{"text": "World", "sensor": [0.4, 0.5], "image": ""}
]
self.tmp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json')
json.dump(self.samples, self.tmp_file)
self.tmp_file.close()
def tearDown(self):
import os
os.unlink(self.tmp_file.name)
def test_dataset_length(self):
ds = OmniCoreXMultiModalDataset(self.tmp_file.name, ["text", "image", "sensor"], tokenizer=dummy_tokenizer)
self.assertEqual(len(ds), 2)
def test_dataloader_batch(self):
dl = create_omncorex_dataloader(self.tmp_file.name, ["text", "image", "sensor"], tokenizer=dummy_tokenizer, batch_size=2)
batch = next(iter(dl))
self.assertIn("text", batch)
self.assertIn("sensor", batch)
self.assertIn("image", batch)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
|