import math
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.callback.core import Callback
from fastai.callback.schedule import *
from fastai.callback.tracker import SaveModelCallback
from fastai.callback.training import GradientClip  # fastai v2 replacement for the v1 GradientClipping callback
from fastai.data.core import *
from fastai.learner import *
from fastai.torch_core import *
from fastcore.basics import range_of
from fastcore.utils import ifnone

from models.base_model import ClassificationModel
from models.basicconv1d import weight_init, fcn_wang, fcn, schirrmeister, sen, basic1d
from models.inception1d import inception1d
from models.resnet1d import resnet1d18, resnet1d34, resnet1d50, resnet1d101, resnet1d152, resnet1d_wang, \
    wrn1d_22
from models.rnn1d import RNN1d
from models.xresnet1d import xresnet1d18_deeper, xresnet1d34_deeper, xresnet1d50_deeper, xresnet1d18_deep, \
    xresnet1d34_deep, xresnet1d50_deep, xresnet1d18, xresnet1d34, xresnet1d101, xresnet1d50, xresnet1d152
from utilities.timeseries_utils import TimeseriesDatasetCrops, ToTensor, aggregate_predictions
from utilities.utils import evaluate_experiment

def add_metrics(last_metrics, new_metric):
    """
    Adds a new metric to the list of last metrics.
    
    Args:
        last_metrics (list): List of previous metrics.
        new_metric (float or list): New metric(s) to add.
    
    Returns:
        list: Updated list of metrics.
    """
    if isinstance(new_metric, list):
        return last_metrics + new_metric
    else:
        return last_metrics + [new_metric]
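
# Example (illustrative values): add_metrics([0.71], 0.83) -> [0.71, 0.83];
# add_metrics([0.71], [0.83, 0.80]) -> [0.71, 0.83, 0.80]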

class MetricFunc(Callback):
    """Obtains score using user-supplied function func (potentially ignoring targets with ignore_idx)"""

    def __init__(self, func, name="MetricFunc", ignore_idx=None, one_hot_encode_target=True, argmax_pred=False,
                 softmax_pred=True, flatten_target=True, sigmoid_pred=False, metric_component=None):
        super().__init__()
        self.func = func
        self.ignore_idx = ignore_idx
        self.one_hot_encode_target = one_hot_encode_target
        self.argmax_pred = argmax_pred
        self.softmax_pred = softmax_pred
        self.flatten_target = flatten_target
        self.sigmoid_pred = sigmoid_pred
        self.metric_component = metric_component
        self.name = name
        # accumulated over the course of an epoch; the full metric is computed in on_epoch_end
        self.y_true = None
        self.y_pred = None
        self.metric_complete = None

    def on_epoch_begin(self, **kwargs):
        # reset the accumulated predictions/targets at the start of each epoch
        self.y_pred = None
        self.y_true = None

    def on_batch_end(self, last_output, last_target, **kwargs):
        # flatten everything (to make it also work for annotation tasks)
        y_pred_flat = last_output.view((-1, last_output.size()[-1]))

        if self.flatten_target:
            last_target = last_target.view(-1)
        y_true_flat = last_target

        # optionally take argmax of predictions
        if self.argmax_pred is True:
            y_pred_flat = y_pred_flat.argmax(dim=1)
        elif self.softmax_pred is True:
            y_pred_flat = F.softmax(y_pred_flat, dim=1)
        elif self.sigmoid_pred is True:
            y_pred_flat = torch.sigmoid(y_pred_flat)

        # potentially remove ignore_idx entries
        if self.ignore_idx is not None:
            selected_indices = (y_true_flat != self.ignore_idx).nonzero().squeeze()
            y_pred_flat = y_pred_flat[selected_indices]
            y_true_flat = y_true_flat[selected_indices]

        y_pred_flat = to_np(y_pred_flat)
        y_true_flat = to_np(y_true_flat)

        if self.one_hot_encode_target is True:
            # one-hot encode integer targets via an identity-matrix lookup
            y_true_flat = np.eye(last_output.size()[-1], dtype=np.float32)[y_true_flat]

        if self.y_pred is None:
            self.y_pred = y_pred_flat
            self.y_true = y_true_flat
        else:
            self.y_pred = np.concatenate([self.y_pred, y_pred_flat], axis=0)
            self.y_true = np.concatenate([self.y_true, y_true_flat], axis=0)

    def on_epoch_end(self, last_metrics, **kwargs):
        # compute the full metric (possibly with multiple components) over the whole epoch
        self.metric_complete = self.func(self.y_true, self.y_pred)
        if self.metric_component is not None:
            return add_metrics(last_metrics, self.metric_complete[self.metric_component])
        else:
            return add_metrics(last_metrics, self.metric_complete)


def fmax_metric(targs, preds):
    return evaluate_experiment(targs, preds)["Fmax"]


def auc_metric(targs, preds):
    return evaluate_experiment(targs, preds)["macro_auc"]


def mse_flat(preds, targs):
    return torch.mean(torch.pow(preds.view(-1) - targs.view(-1), 2))


def nll_regression(preds, targs):
    # preds: bs, 2
    # targs: bs, 1
    preds_mean = preds[:, 0]
    # warning: output goes through exponential map to ensure positivity
    preds_var = torch.clamp(torch.exp(preds[:, 1]), 1e-4, 1e10)
    # print(to_np(preds_mean)[0],to_np(targs)[0,0],to_np(torch.sqrt(preds_var))[0])
    return torch.mean(torch.log(2 * math.pi * preds_var) / 2) + torch.mean(
        torch.pow(preds_mean - targs[:, 0], 2) / 2 / preds_var)
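
# Worked example for nll_regression (illustrative tensors, not experiment code):
# preds = torch.tensor([[0., 0.], [0., 0.]])  # mean 0, log-variance 0 (variance 1)
# targs = torch.tensor([[1.], [2.]])
# nll_regression(preds, targs) == log(2*pi)/2 + (1**2 + 2**2) / (2 * 2)  # ~0.919 + 1.25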


def nll_regression_init(m):
    assert (isinstance(m, nn.Linear))
    nn.init.normal_(m.weight, 0., 0.001)
    nn.init.constant_(m.bias, 4)
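    # note: the constant bias of 4 makes both outputs start near 4; for the
    # log-variance unit this corresponds to an initial variance of e^4 ≈ 55,
    # presumably to keep the predicted variance large and the loss stable
    # early in training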


def lr_find_plot(learner, path, filename="lr_find", n_skip=10, n_skip_end=2):
    """
    saves lr_find plot as file (normally only jupyter output)
    on the x-axis is lrs[-1]
    """
    learner.lr_find()

    backend_old = matplotlib.get_backend()
    plt.switch_backend('agg')
    plt.ylabel("loss")
    plt.xlabel("learning rate (log scale)")
    losses = [to_np(x) for x in learner.recorder.losses[n_skip:-(n_skip_end + 1)]]
    # print(learner.recorder.val_losses)
    # val_losses = [ to_np(x) for x in learner.recorder.val_losses[n_skip:-(n_skip_end+1)]]

    plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end + 1)], losses)
    # plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end+1)],val_losses )

    plt.xscale('log')
    plt.savefig(str(path / (filename + '.png')))
    plt.switch_backend(backend_old)


def losses_plot(learner, path, filename="losses", last: int = None):
    """
    saves lr_find plot as file (normally only jupyter output)
    on the x-axis is lrs[-1]
    """
    backend_old = matplotlib.get_backend()
    plt.switch_backend('agg')
    plt.ylabel("loss")
    plt.xlabel("Batches processed")

    last = ifnone(last, len(learner.recorder.nb_batches))
    l_b = np.sum(learner.recorder.nb_batches[-last:])
    iterations = range_of(learner.recorder.losses)[-l_b:]
    plt.plot(iterations, learner.recorder.losses[-l_b:], label='Train')
    val_iter = learner.recorder.nb_batches[-last:]
    val_iter = np.cumsum(val_iter) + np.sum(learner.recorder.nb_batches[:-last])
    plt.plot(val_iter, learner.recorder.val_losses[-last:], label='Validation')
    plt.legend()

    plt.savefig(str(path / (filename + '.png')))
    plt.switch_backend(backend_old)


class FastaiModel(ClassificationModel):
    def __init__(self, name, n_classes, freq, output_folder, input_shape, pretrained=False, input_size=2.5,
                 input_channels=12, chunkify_train=False, chunkify_valid=True, bs=128, ps_head=0.5, lin_ftrs_head=None,
                 wd=1e-2, epochs=50, lr=1e-2, kernel_size=5, loss="binary_cross_entropy", pretrained_folder=None,
                 n_classes_pretrained=None, gradual_unfreezing=True, discriminative_lrs=True, epochs_finetuning=30,
                 early_stopping=None, aggregate_fn="max", concat_train_val=False):
        super().__init__()

        if lin_ftrs_head is None:
            lin_ftrs_head = [128]
        self.name = name
        self.num_classes = n_classes if loss != "nll_regression" else 2
        self.target_fs = freq
        self.output_folder = Path(output_folder)

        self.input_size = int(input_size * self.target_fs)
        self.input_channels = input_channels

        self.chunkify_train = chunkify_train
        self.chunkify_valid = chunkify_valid

        self.chunk_length_train = 2 * self.input_size  # target_fs*6
        self.chunk_length_valid = self.input_size

        self.min_chunk_length = self.input_size  # chunk_length

        self.stride_length_train = self.input_size  # chunk_length_train//8
        self.stride_length_valid = self.input_size // 2  # chunk_length_valid

        self.copies_valid = 0  # >0 should only be used with chunkify_valid=False
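
        # Example (illustrative numbers): with freq=100 Hz and input_size=2.5 s,
        # self.input_size = 250 samples; training chunks then span 500 samples
        # with stride 250, and validation chunks span 250 samples with stride 125.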

        self.bs = bs
        self.ps_head = ps_head
        self.lin_ftrs_head = lin_ftrs_head
        self.wd = wd
        self.epochs = epochs
        self.lr = lr
        self.kernel_size = kernel_size
        self.loss = loss
        self.input_shape = input_shape

        if pretrained:
            if pretrained_folder is None:
                pretrained_folder = Path('../output/exp0/models/' + name.split("_pretrained")[0] + '/')
                # pretrained_folder = Path('/output/exp0/models/'+name.split("_pretrained")[0]+'/')

            if n_classes_pretrained is None:
                n_classes_pretrained = 71

        self.pretrained_folder = None if pretrained_folder is None else Path(pretrained_folder)
        self.n_classes_pretrained = n_classes_pretrained
        self.discriminative_lrs = discriminative_lrs
        self.gradual_unfreezing = gradual_unfreezing
        self.epochs_finetuning = epochs_finetuning

        self.early_stopping = early_stopping
        self.aggregate_fn = aggregate_fn
        self.concat_train_val = concat_train_val

    def fit(self, X_train, y_train, X_val, y_val):
        # convert everything to float32
        X_train = [l.astype(np.float32) for l in X_train]
        X_val = [l.astype(np.float32) for l in X_val]
        y_train = [l.astype(np.float32) for l in y_train]
        y_val = [l.astype(np.float32) for l in y_val]

        if self.concat_train_val:
            X_train += X_val
            y_train += y_val

        if self.pretrained_folder is None:  # from scratch
            print("Training from scratch...")
            learn = self._get_learner(X_train, y_train, X_val, y_val)

            # if(self.discriminative_lrs):
            #    layer_groups=learn.model.get_layer_groups()
            #    learn.split(layer_groups)
            learn.model.apply(weight_init)

            # initialization for regression output
            if self.loss == "nll_regression" or self.loss == "mse":
                output_layer_new = learn.model.get_output_layer()
                output_layer_new.apply(nll_regression_init)
                learn.model.set_output_layer(output_layer_new)

            lr_find_plot(learn, self.output_folder)
            learn.fit_one_cycle(self.epochs, self.lr)  # slice(self.lr) if self.discriminative_lrs else self.lr)
            losses_plot(learn, self.output_folder)
        else:  # finetuning
            print("Finetuning...")
            # create learner
            learn = self._get_learner(X_train, y_train, X_val, y_val, self.n_classes_pretrained)

            # load pretrained model
            learn.path = self.pretrained_folder
            learn.load(self.pretrained_folder.stem)
            learn.path = self.output_folder

            # exchange top layer
            output_layer = learn.model.get_output_layer()
            output_layer_new = nn.Linear(output_layer.in_features, self.num_classes).cuda()
            apply_init(output_layer_new, nn.init.kaiming_normal_)
            learn.model.set_output_layer(output_layer_new)

            # layer groups (layer_groups stays empty if discriminative lrs are not
            # used, so the lr_find/losses filename suffix below is still defined)
            layer_groups = []
            if self.discriminative_lrs:
                layer_groups = learn.model.get_layer_groups()
                learn.split(layer_groups)

            learn.train_bn = True  # make sure batchnorm layers stay in training mode

            # train
            lr = self.lr
            if self.gradual_unfreezing:
                assert (self.discriminative_lrs is True)
                learn.freeze()
                lr_find_plot(learn, self.output_folder, "lr_find0")
                learn.fit_one_cycle(self.epochs_finetuning, lr)
                losses_plot(learn, self.output_folder, "losses0")
                # for n in [0]:  # range(len(layer_groups)):
                #     learn.freeze_to(-n - 1)
                #     lr_find_plot(learn, self.output_folder, "lr_find" + str(n))
                #     learn.fit_one_cycle(self.epochs_gradual_unfreezing, slice(lr))
                #     losses_plot(learn, self.output_folder, "losses" + str(n))
                #     if n == 0:  # reduce lr after first step
                #         lr /= 10.
                #     if n > 0 and (self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru")):
                #         lr /= 10  # reduce lr further for RNNs

            learn.unfreeze()
            lr_find_plot(learn, self.output_folder, "lr_find" + str(len(layer_groups)))
            learn.fit_one_cycle(self.epochs_finetuning, slice(lr / 1000, lr / 10))
            losses_plot(learn, self.output_folder, "losses" + str(len(layer_groups)))

        learn.save(self.name)  # with early stopping, the best checkpoint has already been reloaded at this point

    def predict(self, X):
        X = [l.astype(np.float32) for l in X]
        y_dummy = [np.ones(self.num_classes, dtype=np.float32) for _ in range(len(X))]

        learn = self._get_learner(X, y_dummy, X, y_dummy)
        learn.load(self.name)

        preds, targs = learn.get_preds()
        preds = to_np(preds)

        idmap = learn.dls.valid_ds.get_id_mapping()
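        # idmap assigns each crop to its source record, e.g. [0, 0, 1] means the
        # first two crop predictions belong to record 0; aggregate_predictions
        # then reduces all crops of a record with aggregate_fn, returning one
        # prediction per input record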

        return aggregate_predictions(preds, idmap=idmap,
                                     aggregate_fn=np.mean if self.aggregate_fn == "mean" else np.amax)

    def _get_learner(self, X_train, y_train, X_val, y_val, num_classes=None):
        df_train = pd.DataFrame({"data": range(len(X_train)), "label": y_train})
        df_valid = pd.DataFrame({"data": range(len(X_val)), "label": y_val})

        tfms_ptb_xl = [ToTensor()]

        ds_train = TimeseriesDatasetCrops(df_train, self.input_size, num_classes=self.num_classes,
                                          chunk_length=self.chunk_length_train if self.chunkify_train else 0,
                                          min_chunk_length=self.min_chunk_length,
                                          stride=self.stride_length_train, transforms=tfms_ptb_xl,
                                          annotation=False, col_lbl="label", npy_data=X_train)
        ds_valid = TimeseriesDatasetCrops(df_valid, self.input_size, num_classes=self.num_classes,
                                          chunk_length=self.chunk_length_valid if self.chunkify_valid else 0,
                                          min_chunk_length=self.min_chunk_length,
                                          stride=self.stride_length_valid, transforms=tfms_ptb_xl,
                                          annotation=False, col_lbl="label", npy_data=X_val)

        # build fastai v2 DataLoaders from the datasets, applying the configured batch size
        db = DataLoaders.from_dsets(ds_train, ds_valid, bs=self.bs)

        if self.loss == "binary_cross_entropy":
            loss = F.binary_cross_entropy_with_logits
        elif self.loss == "cross_entropy":
            loss = F.cross_entropy
        elif self.loss == "mse":
            loss = mse_flat
        elif self.loss == "nll_regression":
            loss = nll_regression
        else:
            raise ValueError("loss " + self.loss + " not found")

        self.input_channels = self.input_shape[-1]
        metrics = []

        print("model:", self.name)
        # note: all models of a particular kind share the same prefix but potentially a different
        # postfix such as _input256
        num_classes = self.num_classes if num_classes is None else num_classes
        # resnet resnet1d18,resnet1d34,resnet1d50,resnet1d101,resnet1d152,resnet1d_wang,resnet1d,wrn1d_22
        if self.name.startswith("fastai_resnet1d18"):
            model = resnet1d18(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                               kernel_size=self.kernel_size, ps_head=self.ps_head,
                               lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d34"):
            model = resnet1d34(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                               kernel_size=self.kernel_size, ps_head=self.ps_head,
                               lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d50"):
            model = resnet1d50(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                               kernel_size=self.kernel_size, ps_head=self.ps_head,
                               lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d101"):
            model = resnet1d101(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d152"):
            model = resnet1d152(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d_wang"):
            model = resnet1d_wang(num_classes=num_classes, input_channels=self.input_channels,
                                  kernel_size=self.kernel_size, ps_head=self.ps_head,
                                  lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_wrn1d_22"):
            model = wrn1d_22(num_classes=num_classes, input_channels=self.input_channels,
                             kernel_size=self.kernel_size, ps_head=self.ps_head,
                             lin_ftrs_head=self.lin_ftrs_head)

        # xresnet ... (order important for string capture)
        elif self.name.startswith("fastai_xresnet1d18_deeper"):
            model = xresnet1d18_deeper(num_classes=num_classes, input_channels=self.input_channels,
                                       kernel_size=self.kernel_size, ps_head=self.ps_head,
                                       lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d34_deeper"):
            model = xresnet1d34_deeper(num_classes=num_classes, input_channels=self.input_channels,
                                       kernel_size=self.kernel_size, ps_head=self.ps_head,
                                       lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d50_deeper"):
            model = xresnet1d50_deeper(num_classes=num_classes, input_channels=self.input_channels,
                                       kernel_size=self.kernel_size, ps_head=self.ps_head,
                                       lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d18_deep"):
            model = xresnet1d18_deep(num_classes=num_classes, input_channels=self.input_channels,
                                     kernel_size=self.kernel_size, ps_head=self.ps_head,
                                     lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d34_deep"):
            model = xresnet1d34_deep(num_classes=num_classes, input_channels=self.input_channels,
                                     kernel_size=self.kernel_size, ps_head=self.ps_head,
                                     lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d50_deep"):
            model = xresnet1d50_deep(num_classes=num_classes, input_channels=self.input_channels,
                                     kernel_size=self.kernel_size, ps_head=self.ps_head,
                                     lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d18"):
            model = xresnet1d18(num_classes=num_classes, input_channels=self.input_channels,
                                kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d34"):
            model = xresnet1d34(num_classes=num_classes, input_channels=self.input_channels,
                                kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d50"):
            model = xresnet1d50(num_classes=num_classes, input_channels=self.input_channels,
                                kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d101"):
            model = xresnet1d101(num_classes=num_classes, input_channels=self.input_channels,
                                 kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d152"):
            model = xresnet1d152(num_classes=num_classes, input_channels=self.input_channels,
                                 kernel_size=self.kernel_size, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)

        # inception: passing the default kernel_size of 5 leads to a maximum kernel size
        # of 8*5-1=39 inside the inception model, as proposed in the original paper
        elif self.name == "fastai_inception1d_no_residual":  # note: order important for string capture
            model = inception1d(num_classes=num_classes, input_channels=self.input_channels,
                                use_residual=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
                                kernel_size=8 * self.kernel_size)
        elif self.name.startswith("fastai_inception1d"):
            model = inception1d(num_classes=num_classes, input_channels=self.input_channels,
                                use_residual=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
                                kernel_size=8 * self.kernel_size)


        # BasicConv1d fcn,fcn_wang,schirrmeister,sen,basic1d
        elif self.name.startswith("fastai_fcn_wang"):  # note: order important for string capture
            model = fcn_wang(num_classes=num_classes, input_channels=self.input_channels,
                             ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_fcn"):
            model = fcn(num_classes=num_classes, input_channels=self.input_channels)
        elif self.name.startswith("fastai_schirrmeister"):
            model = schirrmeister(num_classes=num_classes, input_channels=self.input_channels,
                                  ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_sen"):
            model = sen(num_classes=num_classes, input_channels=self.input_channels, ps_head=self.ps_head,
                        lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_basic1d"):
            model = basic1d(num_classes=num_classes, input_channels=self.input_channels,
                            kernel_size=self.kernel_size, ps_head=self.ps_head,
                            lin_ftrs_head=self.lin_ftrs_head)
        # RNN
        elif self.name.startswith("fastai_lstm_bidir"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
                          bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_gru_bidir"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
                          bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_lstm"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
                          bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_gru"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
                          bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        else:
            raise ValueError("Model " + self.name + " not found.")

        learn = Learner(db, model, loss_func=loss, metrics=metrics, wd=self.wd, path=self.output_folder)

        if self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru"):
            learn.callback_fns.append(partial(GradientClipping, clip=0.25))

        if self.early_stopping is not None:
            # supported options: valid_loss, macro_auc, fmax
            # (fastai v2 SaveModelCallback saves on improvement of the monitored
            # quantity by default; fname replaces the v1 name/every arguments)
            if self.early_stopping == "macro_auc" and self.loss != "mse" and self.loss != "nll_regression":
                metric = MetricFunc(auc_metric, self.early_stopping,
                                    one_hot_encode_target=False, argmax_pred=False, softmax_pred=False,
                                    sigmoid_pred=True, flatten_target=False)
                learn.metrics.append(metric)
                learn.add_cb(SaveModelCallback(monitor=self.early_stopping, fname=self.name))
            elif self.early_stopping == "fmax" and self.loss != "mse" and self.loss != "nll_regression":
                metric = MetricFunc(fmax_metric, self.early_stopping,
                                    one_hot_encode_target=False, argmax_pred=False, softmax_pred=False,
                                    sigmoid_pred=True, flatten_target=False)
                learn.metrics.append(metric)
                learn.add_cb(SaveModelCallback(monitor=self.early_stopping, fname=self.name))
            elif self.early_stopping == "valid_loss":
                learn.add_cb(SaveModelCallback(monitor=self.early_stopping, fname=self.name))

        return learn
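

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumptions, not the original experiment setup):
    # random data shaped like 12-lead ECG records sampled at 100 Hz, a small
    # fcn_wang model, one epoch, and an illustrative output folder.
    out = Path("./output_debug")
    out.mkdir(parents=True, exist_ok=True)

    n_classes, freq, n_records, sig_len = 5, 100, 8, 1000
    rng = np.random.default_rng(0)
    X = [rng.normal(size=(sig_len, 12)).astype(np.float32) for _ in range(n_records)]
    y = [rng.integers(0, 2, size=n_classes).astype(np.float32) for _ in range(n_records)]

    model = FastaiModel("fastai_fcn_wang", n_classes=n_classes, freq=freq,
                        output_folder=out, input_shape=[sig_len, 12], epochs=1, bs=4)
    model.fit(X[:6], y[:6], X[6:], y[6:])
    preds = model.predict(X[6:])  # one aggregated prediction per record
    print("aggregated predictions:", preds.shape)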