# Source: SAE/attacks/AIM/tests/test_datasets/test_datasets.py
# (uploaded by Ttius as part of "Upload 192 files", commit 998bb30).
import os
import unittest
from pathlib import Path
from gat.datasets import build_dataset
# Root directory of the ImageNet-1k data used by the tests below.
# The DATA_ROOT environment variable takes precedence. Otherwise the path
# falls back to <root>/data/in_1k, where <root> is two directories above
# this file's directory (the project root, given this file lives in
# tests/test_datasets/). The value is always a pathlib.Path, whichever
# branch is taken.
in1k_data_root = Path(
    os.environ.get('DATA_ROOT', Path(__file__).parents[2] / 'data' / 'in_1k')
)
class TestImageNet(unittest.TestCase):
    """Sanity checks on the ImageNet dataset produced by ``build_dataset``.

    The tests load the real dataset from ``in1k_data_root``. The whole class
    is skipped when that directory does not exist, so the suite can still run
    on machines without a local copy of the data.
    """

    # Expected sample count of the training split (is_train=True).
    TRAIN_SIZE = 1281167
    # Expected sample count of the validation split restricted to a single
    # class via ``filter_class``.
    VAL_PER_CLASS = 50

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Skip instead of erroring when the dataset is not available locally.
        if not os.path.isdir(in1k_data_root):
            raise unittest.SkipTest(
                f'ImageNet-1k data not found at {in1k_data_root!s}; '
                'set DATA_ROOT to run these tests.'
            )

    def test_in1k(self):
        """The training split has the expected number of samples."""
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=True,
        )
        self.assertEqual(len(ds), self.TRAIN_SIZE)

    def test_in1k_filter(self):
        """Filtering the validation split with ``filter_class=0`` yields 50 samples.

        Presumably ``filter_class`` keeps only samples whose label equals the
        given class index. TODO confirm against ``build_dataset``.
        """
        ds = build_dataset(
            'imagenet',
            data_root=in1k_data_root,
            is_train=False,
            filter_class=0,
        )
        self.assertEqual(len(ds), self.VAL_PER_CLASS)
if __name__ == '__main__':
    # Running this file directly (rather than via a test runner) executes
    # every TestCase defined above with unittest's default CLI runner.
    unittest.main()