| from utils.omega_parser import EPUCNNParams |
| from utils.omega_parser import DataParams, DataPreprocessing |
|
|
|
|
def model_cfg_to_epucnn(model_cfg: EPUCNNParams):
    """Flatten a parsed model configuration into a plain dictionary.

    Top-level options are read directly from ``model_cfg``; everything
    describing a single subnetwork is read from ``model_cfg.subnetwork_config``.

    Args:
        model_cfg: Parsed model configuration object.

    Returns:
        A dict with the keys ``num_classes``, ``subnetwork_name``,
        ``num_subnetworks``, ``epu_activation``, ``subnet_activation`` and a
        nested ``subnet_cfg`` dict (``input_channels``, ``base_channels``,
        ``fc_units``).
    """
    # NOTE(review): presumably consumed as constructor options for the
    # EPU-CNN model -- confirm against the caller.
    subnet = model_cfg.subnetwork_config

    epucnn_options = dict(
        num_classes=model_cfg.num_classes,
        subnetwork_name=subnet.architecture,
        num_subnetworks=model_cfg.num_subnetworks,
        epu_activation=model_cfg.epu_activation,
        subnet_activation=subnet.pred_activation,
    )
    # The config's ``fc_hidden_units`` is exposed under the shorter
    # ``fc_units`` key in the nested subnetwork dictionary.
    epucnn_options["subnet_cfg"] = dict(
        input_channels=subnet.input_channels,
        base_channels=subnet.base_channels,
        fc_units=subnet.fc_hidden_units,
    )
    return epucnn_options
|
|
|
|
def data_cfg_to_dataparser(dataset_path: str,
                           images_extension: str,
                           data_mode: str,
                           preprocessing_cfg: DataPreprocessing,
                           group_by: str = None,):
    """Assemble the option dictionary for a dataset parser.

    How ``group_by`` is forwarded depends on ``preprocessing_cfg.data_parser``:

    * ``"filename"`` / ``"folder"``: ``group_by`` is passed through unchanged
      as ``group_by_key``.
    * ``"medmnist"``: ``group_by`` is looked up in
      ``preprocessing_cfg.label_mapping`` and the result (``None`` when the
      label is absent) is passed as ``group_by_value``.
    * any other parser: both grouping entries stay ``None``.

    Args:
        dataset_path: Stored under ``dataset_folder``.
        images_extension: Stored under ``image_ext``.
        data_mode: Stored under ``mode``.
        preprocessing_cfg: Supplies ``label_mapping``, ``medmnist_csv_file``
            and, when ``group_by`` is given, ``data_parser``.
        group_by: Optional grouping selector; ``None`` disables grouping.

    Returns:
        A dict with the keys ``dataset_folder``, ``mode``, ``image_ext``,
        ``label_mapping``, ``csv_file``, ``group_by_key`` and
        ``group_by_value``.
    """
    key_selector, value_selector = None, None

    # The parser type is only consulted when grouping was actually requested.
    if group_by is not None:
        parser_kind = preprocessing_cfg.data_parser
        if parser_kind in ("filename", "folder"):
            key_selector = group_by
        elif parser_kind in ("medmnist",):
            value_selector = preprocessing_cfg.label_mapping.get(group_by)

    parser_options = dict(
        dataset_folder=dataset_path,
        mode=data_mode,
        image_ext=images_extension,
        label_mapping=preprocessing_cfg.label_mapping,
        csv_file=preprocessing_cfg.medmnist_csv_file,
    )
    parser_options["group_by_key"] = key_selector
    parser_options["group_by_value"] = value_selector
    return parser_options
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|