File size: 3,744 Bytes
5a169ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import torch
from HD_BET.utils import SetNetworkToVal, softmax_helper
from abc import abstractmethod
from HD_BET.network_architecture import Network


class BaseConfig(object):
    """Abstract base class for experiment configurations.

    Subclasses supply dataset splits, network construction and data
    generators; ``preprocess`` may be overridden for input normalization.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_split(self, fold, random_state=12345):
        """Return the train/validation split for the given fold."""
        pass

    @abstractmethod
    def get_network(self, mode="train"):
        """Build and return the network for the given mode."""
        pass

    @abstractmethod
    def get_basic_generators(self, fold):
        """Return basic (non-augmented) data generators for the fold."""
        pass

    @abstractmethod
    def get_data_generators(self, fold):
        """Return (augmented) training/validation data generators."""
        pass

    def preprocess(self, data):
        """Identity preprocessing; subclasses may normalize here."""
        return data

    def __repr__(self):
        # One "name: value" line per public attribute; 'dataset' is excluded
        # because it is typically huge. A single "_" prefix check suffices —
        # the original extra "__" check was redundant (subsumed by "_").
        return "".join(
            f"{name}: {getattr(self, name)}\n"
            for name in vars(self)
            if not name.startswith("_") and name != 'dataset'
        )


class HD_BET_Config(BaseConfig):
    """Configuration for the HD-BET brain-extraction network.

    Holds architecture hyperparameters, patch and validation settings,
    builds the network and performs per-channel z-score preprocessing.
    """

    def __init__(self):
        super(HD_BET_Config, self).__init__()

        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name

        # network parameters
        self.net_base_num_layers = 21
        self.BATCH_SIZE = 2
        self.net_do_DS = True  # deep supervision during training
        self.net_dropout_p = 0.0
        self.net_use_inst_norm = True
        self.net_conv_use_bias = True
        self.net_norm_use_affine = True
        self.net_leaky_relu_slope = 1e-1

        # hyperparameters
        self.INPUT_PATCH_SIZE = (128, 128, 128)
        self.num_classes = 2
        self.selected_data_channels = range(1)  # single input channel

        # data augmentation
        # axes 2, 3, 4 — presumably the spatial axes of a (b, c, x, y, z)
        # batch; confirm against the generator code
        self.da_mirror_axes = (2, 3, 4)

        # validation
        self.val_use_DO = False
        self.val_use_train_mode = False  # for dropout sampling
        self.val_num_repeats = 1  # only useful if dropout sampling
        self.val_batch_size = 1  # only useful if dropout sampling
        self.val_save_npz = True
        self.val_do_mirroring = True  # test time data augmentation via mirroring
        self.val_write_images = True
        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
        self.val_min_size = self.INPUT_PATCH_SIZE
        self.val_fn = None

        # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_running_
        # stats=True but not using them in validation. keep_running_stats was True before 0.3 but unused and defaults
        # to false in 0.4)
        self.val_use_moving_averages = False

    def get_network(self, train=True, pretrained_weights=None):
        """Construct the segmentation network.

        :param train: if True, return the net in training mode; otherwise
            switch it to eval behavior and disable deep supervision.
        :param pretrained_weights: optional path to a saved state dict.
        :return: tuple ``(net, optimizer)`` — optimizer is always None here.
        """
        net = Network(self.num_classes, len(self.selected_data_channels),
                      self.net_base_num_layers, self.net_dropout_p,
                      softmax_helper, self.net_leaky_relu_slope,
                      self.net_conv_use_bias, self.net_norm_use_affine,
                      True, self.net_do_DS)

        if pretrained_weights is not None:
            # map_location keeps tensors on CPU so loading works without a GPU
            net.load_state_dict(
                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))

        net.train(train)
        if not train:
            # configure dropout / running-stat usage for validation and
            # disable deep supervision outputs
            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
            net.do_ds = False

        optimizer = None
        self.lr_scheduler = None
        return net, optimizer

    def get_data_generators(self, fold):
        pass

    def get_split(self, fold, random_state=12345):
        pass

    def get_basic_generators(self, fold):
        pass

    def on_epoch_end(self, epoch):
        pass

    def preprocess(self, data):
        """Z-score normalize each channel (axis 0) independently.

        Works on a copy; the input array is never modified. Integer input is
        promoted to float32 first — the original in-place float subtraction
        raised a casting error on int arrays. Floating-point input follows
        the original code path unchanged.
        NOTE(review): a constant channel (std == 0) still yields NaN/inf,
        exactly as before.
        """
        data = np.copy(data)
        if not np.issubdtype(data.dtype, np.floating):
            # in-place -= / /= with float results is illegal on int arrays
            data = data.astype(np.float32)
        for c in range(data.shape[0]):
            data[c] -= data[c].mean()
            data[c] /= data[c].std()
        return data


# NOTE(review): this binds the class itself, not an instance — presumably the
# caller does `config()` to instantiate; confirm usage before changing.
config = HD_BET_Config