| import os | |
| import unittest | |
| from pathlib import Path | |
| from gat.datasets import build_dataset | |
# Location of the ImageNet-1k data used by the tests below.
# The DATA_ROOT environment variable takes precedence. Otherwise it defaults to
# <dir>/data/in_1k, where <dir> is Path(__file__).parents[2], i.e. two levels
# above the directory containing this file (presumably the repository root --
# confirm if this file moves).
# NOTE: the value is a str when DATA_ROOT is set and a pathlib.Path otherwise.
# (A previous hard-coded, machine-specific absolute path that was immediately
# overwritten has been removed as dead code.)
in1k_data_root = os.environ.get(
    'DATA_ROOT',
    Path(__file__).parents[2] / 'data' / 'in_1k',
)
# These tests need the ImageNet-1k files on disk. Skip them (instead of
# erroring) when the configured data root does not exist.
@unittest.skipUnless(
    Path(in1k_data_root).is_dir(),
    'ImageNet-1k data not found; set DATA_ROOT to enable these tests',
)
class TestImageNet(unittest.TestCase):
    """Check the sizes of the ImageNet datasets returned by ``build_dataset``."""

    def _build(self, **kwargs):
        """Build the ``'imagenet'`` dataset at ``in1k_data_root`` with extra options."""
        return build_dataset('imagenet', data_root=in1k_data_root, **kwargs)

    def test_in1k(self):
        """The training split should contain exactly 1,281,167 samples."""
        ds = self._build(is_train=True)
        # 1,281,167 matches the standard ILSVRC-2012 training-set size.
        self.assertEqual(len(ds), 1281167)

    def test_in1k_filter(self):
        """Filtering the validation split to class 0 should leave 50 samples."""
        ds = self._build(is_train=False, filter_class=0)
        # Expected per-class count in the validation split.
        self.assertEqual(len(ds), 50)
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's test cases.
    unittest.main()