File size: 1,321 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

from typing import Dict, Type

from .ab_dataset import ABDataset


static_dataset_registery = {}

# (NIPS'20) Measuring Robustness to Natural Distribution Shifts in Image Classification
POSSIBLE_SHIFT_TYPES = ['Consistency Shifts', 'Dataset Shifts', 'Adversarially Filtered Shifts',
               'Image Corruptions', 'Geometric Transformations', 'Style Transfer', 'Adversarial Examples']

def dataset_register(name, classes, task_type, object_type, class_aliases=[], shift_type: Dict[str, str]=None):
    assert shift_type is None or len(shift_type.keys()) == 1
    if shift_type is not None:
        assert list(shift_type.values())[0] in POSSIBLE_SHIFT_TYPES
            
    class _Register:
        def __init__(self, func: Type[ABDataset]):
            self.func = func
            static_dataset_registery[name] = (self, classes, task_type, object_type, class_aliases, shift_type)

        def __call__(self, *args, **kwargs):
            res = self.func(*args, **kwargs)
            
            res.name = name
            res.classes = classes
            res.class_aliases = class_aliases
            res.shift_type = shift_type
            res.task_type = task_type
            res.object_type = object_type
            
            res.build()
            
            return res
    
    return _Register