| | import numpy as np |
| | import torch |
| | from HD_BET.utils import SetNetworkToVal, softmax_helper |
| | from abc import abstractmethod |
| | from HD_BET.network_architecture import Network |
| |
|
| |
|
class BaseConfig(object):
    """Base class for experiment configurations.

    Subclasses override the hook methods below and store their settings as
    instance attributes, which ``__repr__`` prints.

    NOTE: this class derives from ``object`` rather than ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced: ``BaseConfig`` itself can be
    instantiated, and any hook a subclass does not override returns ``None``.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_split(self, fold, random_state=12345):
        """Subclass hook; the base implementation does nothing and returns ``None``."""
        pass

    @abstractmethod
    def get_network(self, mode="train"):
        """Subclass hook; the base implementation does nothing and returns ``None``."""
        pass

    @abstractmethod
    def get_basic_generators(self, fold):
        """Subclass hook; the base implementation does nothing and returns ``None``."""
        pass

    @abstractmethod
    def get_data_generators(self, fold):
        """Subclass hook; the base implementation does nothing and returns ``None``."""
        pass

    def preprocess(self, data):
        """Return ``data`` unchanged (identity); subclasses may override."""
        return data

    def __repr__(self):
        """Return one ``"name: value\\n"`` line per public instance attribute.

        Attributes are listed in insertion order. Names starting with an
        underscore (which includes dunder names) and the ``dataset`` attribute
        are skipped; values are rendered with ``str()``.
        """
        return "".join(
            name + ": " + str(getattr(self, name)) + "\n"
            for name in vars(self)
            if not name.startswith("_") and name != "dataset"
        )
| |
|
| |
|
class HD_BET_Config(BaseConfig):
    """Configuration used to build and run the HD-BET network.

    Holds architecture, input-geometry and validation settings as instance
    attributes, builds the model in ``get_network`` and normalizes input arrays
    in ``preprocess``. The split / data-generator / epoch hooks are no-ops that
    return ``None``.
    """

    def __init__(self):
        super(HD_BET_Config, self).__init__()

        self.EXPERIMENT_NAME = self.__class__.__name__

        # Network architecture. Everything here except BATCH_SIZE and
        # net_use_inst_norm is passed to Network in get_network.
        self.net_base_num_layers = 21  # passed as Network's 3rd positional argument
        self.BATCH_SIZE = 2
        self.net_do_DS = True
        self.net_dropout_p = 0.0
        self.net_use_inst_norm = True  # NOTE(review): not read by get_network
        self.net_conv_use_bias = True
        self.net_norm_use_affine = True
        self.net_leaky_relu_slope = 1e-1

        # Input geometry and outputs.
        self.INPUT_PATCH_SIZE = (128, 128, 128)
        self.num_classes = 2
        # len() of this is the network's input-channel count (see get_network).
        self.selected_data_channels = range(1)

        # Axes to mirror -- presumably for augmentation / test-time mirroring; confirm with callers.
        self.da_mirror_axes = (2, 3, 4)

        # Validation / inference settings.
        self.val_use_DO = False  # forwarded to SetNetworkToVal when train=False
        self.val_use_train_mode = False
        self.val_num_repeats = 1
        self.val_batch_size = 1
        self.val_save_npz = True
        self.val_do_mirroring = True
        self.val_write_images = True
        self.net_input_must_be_divisible_by = 16
        self.val_min_size = self.INPUT_PATCH_SIZE
        self.val_fn = None

        self.val_use_moving_averages = False  # forwarded to SetNetworkToVal when train=False

    def get_network(self, train=True, pretrained_weights=None):
        """Build the network, optionally load weights, and set its mode.

        Args:
            train: truthy -> training mode. Falsy -> eval mode, with
                ``SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages)``
                applied (via ``net.apply``) to every submodule and
                ``net.do_ds`` set to ``False``.
            pretrained_weights: optional source for ``torch.load``; the loaded
                object is passed to ``net.load_state_dict``. All tensors are
                loaded onto the CPU.

        Returns:
            ``(net, optimizer)``; ``optimizer`` is always ``None``.

        Side effect: sets ``self.lr_scheduler`` to ``None``.
        """
        # NOTE(review): the literal True is a positional Network argument whose
        # meaning is defined in HD_BET.network_architecture -- confirm which.
        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope,
                      self.net_conv_use_bias, self.net_norm_use_affine, True, self.net_do_DS)

        if pretrained_weights is not None:
            # A map_location that returns the storage unchanged keeps every tensor
            # on the CPU, so GPU-saved checkpoints load on CPU-only machines.
            net.load_state_dict(
                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))

        if train:
            net.train(True)
        else:
            net.train(False)
            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
            net.do_ds = False

        optimizer = None
        self.lr_scheduler = None
        return net, optimizer

    def get_data_generators(self, fold):
        """No-op hook; returns ``None``."""
        pass

    def get_split(self, fold, random_state=12345):
        """No-op hook; returns ``None``."""
        pass

    def get_basic_generators(self, fold):
        """No-op hook; returns ``None``."""
        pass

    def on_epoch_end(self, epoch):
        """No-op hook; returns ``None``."""
        pass

    def preprocess(self, data):
        """Return a z-score-normalized copy of ``data``, normalized per index of axis 0.

        Each ``data[c]`` (presumably one channel) is shifted to zero mean and
        scaled to unit standard deviation. The input is never modified.

        - A slice with zero standard deviation (constant, e.g. all background)
          is only mean-centred (it becomes all zeros). Dividing it by its std
          would be 0/0 and fill the slice with NaN.
        - Bool/integer input is promoted to float32. Floating-point input keeps
          its dtype.
        """
        data = np.copy(data)
        if data.dtype.kind in "biu":
            # In-place float arithmetic on bool/int arrays raises under NumPy's
            # same_kind casting rule, so promote to a floating dtype first.
            data = data.astype(np.float32)
        for c in range(data.shape[0]):
            data[c] -= data[c].mean()
            std = data[c].std()
            if std > 0:
                data[c] /= std
        return data
| |
|
| |
|
# Module-level alias bound to the HD_BET_Config class itself (not an instance).
# NOTE(review): presumably looked up by name (`config`) by code that loads this
# module -- confirm before renaming.
config = HD_BET_Config
| |
|
| |
|