File size: 824 Bytes
998bb30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import unittest
from pathlib import Path

from gat.datasets import build_dataset

# Root directory of the ImageNet-1k data used by the tests below.
# The DATA_ROOT environment variable overrides it (yielding a str). Otherwise
# it falls back to `data/in_1k` under the directory three path components above
# this file (presumably the repository root -- confirm if the file moves),
# which yields a pathlib.Path.
in1k_data_root = os.environ.get('DATA_ROOT',
                                Path(__file__).parents[2] / 'data' / 'in_1k')


class TestImageNet(unittest.TestCase):
    """Size checks for the 'imagenet' dataset returned by ``build_dataset``.

    These tests read the real dataset from ``in1k_data_root``. They are
    skipped, rather than failed, when that directory does not exist, so the
    suite can run on machines without a local copy of the data.
    """

    # Expected number of samples returned with is_train=True.
    TRAIN_SIZE = 1281167
    # Expected number of samples returned with is_train=False, filter_class=0
    # (presumably the validation images of class index 0 only).
    VAL_SIZE_PER_CLASS = 50

    def setUp(self):
        # Checked at test time (not import time) so that a missing dataset
        # shows up as a skip with an actionable message.
        if not os.path.isdir(in1k_data_root):
            self.skipTest(f'ImageNet-1k data not found at {in1k_data_root}; '
                          'set DATA_ROOT to enable these tests')

    def test_in1k(self):
        """The unfiltered training split has the expected number of samples."""
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=True,
        )
        self.assertEqual(len(ds), self.TRAIN_SIZE)

    def test_in1k_filter(self):
        """Filtering the non-training split to class 0 keeps the expected count."""
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=False,
            filter_class=0,
        )
        self.assertEqual(len(ds), self.VAL_SIZE_PER_CLASS)


if __name__ == '__main__':
    # Executed as a script: discover and run the TestCase classes defined in
    # this module (module='__main__' is unittest.main's default, made explicit).
    unittest.main(module='__main__')