"""Size checks for the ImageNet-1k dataset produced by ``build_dataset``.

The dataset location is read from the ``DATA_ROOT`` environment variable and
falls back to ``data/in_1k`` two directories above the directory containing
this file. If that directory does not exist, the tests are skipped rather than
failing inside the dataset loader.
"""
import os
import unittest
from pathlib import Path

from gat.datasets import build_dataset

# Resolve the dataset root once at import time. ``os.environ.get`` returns a
# ``str`` when DATA_ROOT is set but the ``Path`` default otherwise; wrapping
# the result in ``Path`` gives both branches the same type.
in1k_data_root = Path(
    os.environ.get('DATA_ROOT', Path(__file__).parents[2] / 'data' / 'in_1k')
)


@unittest.skipUnless(
    in1k_data_root.is_dir(),
    f'ImageNet-1k data not found at {in1k_data_root}; set DATA_ROOT to its location',
)
class TestImageNet(unittest.TestCase):
    """Check the number of samples returned for ImageNet-1k splits."""

    def test_in1k(self):
        """The training split (``is_train=True``) contains 1,281,167 samples."""
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=True,
        )
        self.assertEqual(len(ds), 1281167)

    def test_in1k_filter(self):
        """The non-training split with ``filter_class=0`` contains 50 samples.

        ``filter_class=0`` presumably restricts the split to class index 0;
        confirm against ``build_dataset``.
        """
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=False,
            filter_class=0,
        )
        self.assertEqual(len(ds), 50)


if __name__ == '__main__':
    unittest.main()