File size: 2,128 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Optional

from utils.omega_parser import EPUCNNParams
from utils.omega_parser import DataParams, DataPreprocessing


def model_cfg_to_epucnn(model_cfg: EPUCNNParams):
    """Flatten an EPU-CNN config object into a plain keyword dict.

    Top-level fields are copied from ``model_cfg``. Sub-network fields are
    read from ``model_cfg.subnetwork_config``. Two sub-network fields are
    renamed on the way:

    - ``architecture`` becomes ``subnetwork_name``
    - ``fc_hidden_units`` becomes ``subnet_cfg["fc_units"]``

    The result is presumably consumed as keyword arguments by the EPU-CNN
    model constructor (verify against caller).

    Args:
        model_cfg: Parsed model configuration.

    Returns:
        dict with keys ``num_classes``, ``subnetwork_name``,
        ``num_subnetworks``, ``epu_activation``, ``subnet_activation`` and a
        nested ``subnet_cfg`` dict (``input_channels``, ``base_channels``,
        ``fc_units``).
    """
    subnet = model_cfg.subnetwork_config

    # Per-subnetwork layer sizes, grouped under their own key.
    subnet_kwargs = dict(
        input_channels=subnet.input_channels,
        base_channels=subnet.base_channels,
        fc_units=subnet.fc_hidden_units,
    )

    return dict(
        num_classes=model_cfg.num_classes,
        subnetwork_name=subnet.architecture,
        num_subnetworks=model_cfg.num_subnetworks,
        epu_activation=model_cfg.epu_activation,
        subnet_activation=subnet.pred_activation,
        subnet_cfg=subnet_kwargs,
    )


def data_cfg_to_dataparser(dataset_path: str,
                           images_extension: str,
                           data_mode: str,
                           preprocessing_cfg: DataPreprocessing,
                           group_by: Optional[str] = None):
    """Build the keyword dict describing how to parse a dataset.

    Args:
        dataset_path: Dataset root; returned as ``dataset_folder``.
        images_extension: Image file extension; returned as ``image_ext``.
        data_mode: Dataset mode (e.g. a split name); returned as ``mode``.
        preprocessing_cfg: Preprocessing config. Its ``data_parser``,
            ``label_mapping`` and ``medmnist_csv_file`` fields are read.
        group_by: Optional label name (e.g. ``"normal"``) to group by. How it
            is used depends on ``preprocessing_cfg.data_parser``:

            - ``"filename"`` / ``"folder"``: passed through as
              ``group_by_key``.
            - ``"medmnist"``: translated through ``label_mapping`` into
              ``group_by_value``.
            - any other parser: ignored.

    Returns:
        dict with keys ``dataset_folder``, ``mode``, ``image_ext``,
        ``label_mapping``, ``csv_file``, ``group_by_key`` and
        ``group_by_value``. The two ``group_by_*`` entries are ``None``
        unless set as described above.

    Raises:
        ValueError: if ``group_by`` is given for the ``"medmnist"`` parser but
            is not a key of ``preprocessing_cfg.label_mapping``.
    """
    group_by_key = None
    group_by_value = None
    if group_by is not None:
        parser = preprocessing_cfg.data_parser
        if parser in ("filename", "folder"):
            # Path-based parsers match on the label name itself, e.g. "normal".
            group_by_key = group_by
        elif parser == "medmnist":
            # MedMNIST stores encoded labels, so translate the name through the
            # mapping, e.g. "normal" -> "0" for {"normal": "0"}.
            label_mapping = preprocessing_cfg.label_mapping
            # Fail loudly instead of silently producing group_by_value=None,
            # which would disable the grouping the caller asked for.
            if label_mapping is None or group_by not in label_mapping:
                raise ValueError(
                    f"group_by={group_by!r} is not a label in the "
                    f"preprocessing label_mapping {label_mapping!r}"
                )
            group_by_value = label_mapping[group_by]

    return {
        "dataset_folder": dataset_path,
        "mode": data_mode,
        "image_ext": images_extension,
        "label_mapping": preprocessing_cfg.label_mapping,
        "csv_file": preprocessing_cfg.medmnist_csv_file,
        "group_by_key": group_by_key,
        "group_by_value": group_by_value
    }


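# NOTE: disabled alternative of data_cfg_to_dataparser. It took a whole
# DataParams object and nested label_mapping/csv_file under "kwargs". The
# active function of the same name above takes the fields explicitly.
# Kept for reference only.
# def data_cfg_to_dataparser(data_cfg: DataParams, data_mode: str = 'train'):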
#     return {
#         "dataset_folder": data_cfg.dataset_path,
#         "mode": data_mode,
#         "image_ext": data_cfg.images_extension,
#         "kwargs": {
#             "label_mapping": data_cfg.data_preprocessing.label_mapping,
#             "csv_file": data_cfg.data_preprocessing.medmnist_csv_file
#         }
#     }