import sys

from .omega_parser import IUSConfig


class SanityChecker:
    """Validate an ``IUSConfig`` for internal consistency before training.

    Each check method inspects part of the configuration and raises
    ``ValueError`` on the first inconsistency it finds. ``sanity_check`` runs
    all of them in order and prints a confirmation when none fail.
    """

    # Supported values for model.subnetwork_config.architecture.
    VALID_MODELS = ["base_one"]
    # Supported values for data_params.data_preprocessing.data_mode.
    VALID_DATA_MODES = ["grayscale", "rgb"]
    # Supported values for data_params.data_preprocessing.data_parser.
    VALID_DATA_PARSERS = ["filename", "folder", "medmnist"]

    def __init__(self, cfg: "IUSConfig"):
        """Store the configuration to validate.

        Args:
            cfg: Parsed configuration. The annotation is a string (forward
                reference) so the class can be defined without ``IUSConfig``
                being resolved at definition time.
        """
        self.cfg = cfg

    def model_cfg(self) -> None:
        """Check the ``model`` section.

        Raises:
            ValueError: If the subnetwork architecture is unsupported,
                ``num_classes`` is below 1, a ``sigmoid`` EPU activation is
                combined with ``num_classes > 1``, or ``num_subnetworks`` is
                not 4.
        """
        model = self.cfg.model

        architecture = model.subnetwork_config.architecture
        if architecture not in self.VALID_MODELS:
            raise ValueError(
                f"model.subnetwork_config.architecture={architecture!r} is not supported. "
                f"EPU-CNN backbone implemented {self.VALID_MODELS}. "
                "Original IUS paper was using 'base_one'"
            )

        num_classes = model.num_classes
        # Without this guard a non-positive value slips past every other
        # num_classes comparison in this class.
        if num_classes < 1:
            raise ValueError(
                f"model.num_classes={num_classes}, but it must be at least 1"
            )

        epu_activation = model.epu_activation
        if epu_activation == "sigmoid" and num_classes > 1:
            raise ValueError(
                f"epu_activation = {epu_activation} and num_classes = {num_classes}. "
                "For sigmoid epu_activation you have to set num_classes=1"
            )

        num_subnets = model.num_subnetworks
        if num_subnets != 4:
            # Adjacent literals are joined verbatim, so each piece carries its
            # own separator.
            raise ValueError(
                f"model.num_subnetworks={num_subnets}, but the IUS measure was "
                "implemented using 4 PFMs (either for grayscale or rgb modalities): "
                "class PerceptualFeatureMapTransform in data.perceptual_transforms "
                "yields 4 PFM representations."
            )

    def train_cfg(self) -> None:
        """Check the ``train_params`` section.

        Raises:
            ValueError: If ``mode`` is ``"binary"`` but ``loss`` is not
                ``"binary_cross_entropy"``.
        """
        mode = self.cfg.train_params.mode
        loss = self.cfg.train_params.loss

        if mode == "binary" and loss != "binary_cross_entropy":
            raise ValueError(
                f"In train_params: mode={mode}, loss={loss}. "
                "For mode = 'binary' set loss = 'binary_cross_entropy'"
            )

    def data_preprocessing(self) -> None:
        """Check ``data_params.data_preprocessing`` against the model config.

        Raises:
            ValueError: If ``data_mode`` or ``data_parser`` is unsupported,
                ``model.subnetwork_config.input_channels`` is not 1, or the
                number of distinct values in ``label_mapping`` does not match
                ``model.num_classes`` (exactly 2 when ``num_classes == 1``).
        """
        preprocessing = self.cfg.data_params.data_preprocessing

        data_mode = preprocessing.data_mode
        if data_mode not in self.VALID_DATA_MODES:
            raise ValueError(
                f"data_mode={data_mode!r}; "
                f"data_mode should be one of {self.VALID_DATA_MODES}."
            )

        channels = self.cfg.model.subnetwork_config.input_channels
        if channels != 1:
            raise ValueError(
                f"In model.subnetwork_config: input_channels={channels}, but Perceptual "
                "Feature decomposition yields PFMs with 1 output channel. "
                "Please set input_channels=1"
            )

        num_classes = self.cfg.model.num_classes
        # set() collapses duplicate values, so only distinct labels are counted.
        labels = set(preprocessing.label_mapping.values())
        # num_classes == 1 is the binary setting and must still see two labels.
        expected_labels = 2 if num_classes == 1 else num_classes
        if num_classes >= 1 and len(labels) != expected_labels:
            raise ValueError(
                f"In model: num_classes={num_classes} expects {expected_labels} "
                f"distinct labels, but label mapping has {len(labels)}: {labels}."
            )

        data_parser = preprocessing.data_parser
        if data_parser not in self.VALID_DATA_PARSERS:
            raise ValueError(
                f"data_parser={data_parser!r}; "
                f"data_parser should be one of {self.VALID_DATA_PARSERS}. "
                "Otherwise implement your own data_parser in data.parsers "
                "and update utils.sanity_utils"
            )

    def sanity_check(self) -> None:
        """Run every check in order and print a confirmation if all pass.

        Raises:
            ValueError: Propagated from the first failing check.
        """
        self.model_cfg()
        self.train_cfg()
        self.data_preprocessing()
        print("All sanity checks passed!")


if __name__ == "__main__":
    # The module-level relative import means this file has to be executed as
    # part of its package (presumably ``python -m utils.sanity_utils``),
    # not as a bare script path.
    # An optional first CLI argument overrides the default config path.
    cfg_file = sys.argv[1] if len(sys.argv) > 1 else "configs/train_config.yaml"
    cfg = IUSConfig.from_yaml(cfg_file)

    checker = SanityChecker(cfg)
    checker.sanity_check()