# ius/utils/sanity_utils.py
# Last synced by pgatoula: "Sync from GitHub via hub-sync" (commit 99ec8a2, verified)
from .omega_parser import IUSConfig
class SanityChecker:
    """Validate an ``IUSConfig`` for internal consistency before training.

    Each check method raises ``ValueError`` on the first inconsistency it
    finds; :meth:`sanity_check` runs all of them and prints a confirmation
    when every check passes. Error messages name the offending config key
    and its current value.
    """

    # Accepted values for model.subnetwork_config.architecture.
    VALID_MODELS = ["base_one"]
    # Accepted values for data_params.data_preprocessing.data_mode.
    VALID_DATA_MODES = ["grayscale", "rgb"]
    # Accepted values for data_params.data_preprocessing.data_parser.
    VALID_DATA_PARSERS = ["filename", "folder", "medmnist"]

    def __init__(self, cfg: "IUSConfig"):
        """Store the configuration to validate.

        Args:
            cfg: Parsed configuration; the checks only read attributes from it.
        """
        self.cfg = cfg

    def model_cfg(self) -> None:
        """Validate the ``model`` section.

        Raises:
            ValueError: If the backbone architecture is not in ``VALID_MODELS``,
                if ``epu_activation`` is ``"sigmoid"`` while ``num_classes > 1``,
                or if ``num_subnetworks`` is not 4.
        """
        model = self.cfg.model

        architecture = model.subnetwork_config.architecture
        if architecture not in self.VALID_MODELS:
            raise ValueError(
                f"model.subnetwork_config.architecture={architecture!r} is not supported. "
                f"EPU-CNN backbone implemented {self.VALID_MODELS}. "
                f"Original IUS paper was using 'base_one'"
            )

        num_classes = model.num_classes
        epu_activation = model.epu_activation
        if epu_activation == "sigmoid" and num_classes > 1:
            raise ValueError(
                f"epu_activation = {epu_activation} and num_classes = {num_classes}. "
                f"For sigmoid epu_activation you have to set num_classes=1"
            )

        num_subnets = model.num_subnetworks
        if num_subnets != 4:
            raise ValueError(
                f"model.num_subnetworks={num_subnets}, but the IUS measure was implemented "
                "using 4 PFMs (either for grayscale or rgb modalities): "
                "class PerceptualFeatureMapTransform in data.perceptual_transforms "
                "yields 4 PFM representations. Please set num_subnetworks=4"
            )

    def train_cfg(self) -> None:
        """Validate the ``train_params`` section.

        Raises:
            ValueError: If ``mode`` is ``"binary"`` but ``loss`` is not
                ``"binary_cross_entropy"``.
        """
        mode = self.cfg.train_params.mode
        loss = self.cfg.train_params.loss
        if mode == "binary" and loss != "binary_cross_entropy":
            raise ValueError(
                f"In train_params: mode={mode}, loss={loss}. For mode = 'binary' set loss = 'binary_cross_entropy'"
            )

    def data_preprocessing(self) -> None:
        """Validate ``data_params.data_preprocessing`` against the model config.

        Raises:
            ValueError: If ``data_mode`` or ``data_parser`` is not an accepted
                value, if ``model.subnetwork_config.input_channels`` is not 1,
                if ``model.num_classes`` is below 1, or if the number of
                distinct values in ``label_mapping`` does not match
                ``num_classes`` (two are required when ``num_classes == 1``).
        """
        preprocessing = self.cfg.data_params.data_preprocessing

        data_mode = preprocessing.data_mode
        channels = self.cfg.model.subnetwork_config.input_channels
        if data_mode not in self.VALID_DATA_MODES:
            raise ValueError(
                f"data_mode={data_mode!r} is not supported. "
                f"data_mode should be one of {self.VALID_DATA_MODES}."
            )
        # The perceptual feature decomposition emits single-channel PFMs
        # regardless of data_mode, so subnetworks must take 1 input channel.
        if channels != 1:
            raise ValueError(
                f"model.subnetwork_config.input_channels={channels}, but Perceptual Feature "
                f"decomposition yields PFMs with 1 output channel. Please set input_channels=1"
            )

        num_classes = self.cfg.model.num_classes
        if num_classes < 1:
            # Without this guard, num_classes <= 0 skipped every label check.
            raise ValueError(f"model.num_classes={num_classes}; it must be at least 1.")
        labels = set(preprocessing.label_mapping.values())
        # num_classes == 1 is the single-output setup, which still needs two labels.
        expected_labels = 2 if num_classes == 1 else num_classes
        if len(labels) != expected_labels:
            raise ValueError(
                f"model.num_classes={num_classes} requires {expected_labels} distinct labels, "
                f"but data_params.data_preprocessing.label_mapping contains {labels}."
            )

        data_parser = preprocessing.data_parser
        if data_parser not in self.VALID_DATA_PARSERS:
            raise ValueError(
                f"data_parser={data_parser!r} is not supported. "
                f"data_parser should be one of {self.VALID_DATA_PARSERS}. "
                f"Otherwise implement your own data_parser in data.parsers and update utils.sanity_utils"
            )

    def sanity_check(self) -> None:
        """Run every check in order, then print a success message.

        Raises:
            ValueError: Propagated from the first failing check.
        """
        self.model_cfg()
        self.train_cfg()
        self.data_preprocessing()
        print("All sanity checks passed!")
if __name__ == "__main__":
    # The module uses relative imports, so run it inside its package,
    # e.g. ``python -m utils.sanity_utils [path/to/config.yaml]``.
    import sys

    # An optional first CLI argument overrides the default config path.
    cfg_file = sys.argv[1] if len(sys.argv) > 1 else "configs/train_config.yaml"
    cfg = IUSConfig.from_yaml(cfg_file)
    checker = SanityChecker(cfg)
    checker.sanity_check()