import math
from functools import partial
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from fastai.data.core import *
from fastai.learner import *
from fastai.callback.schedule import *
from fastai.torch_core import *
from fastai.callback.core import Callback
from fastai.callback.tracker import SaveModelCallback
from fastai.tabular.core import range_of
from fastcore.utils import ifnone

from models.base_model import ClassificationModel
from models.basicconv1d import weight_init, fcn_wang, fcn, schirrmeister, sen, basic1d
from models.inception1d import inception1d
from models.resnet1d import (resnet1d18, resnet1d34, resnet1d50, resnet1d101, resnet1d152,
                             resnet1d_wang, wrn1d_22)
from models.rnn1d import RNN1d
from models.xresnet1d import (xresnet1d18, xresnet1d34, xresnet1d50, xresnet1d101, xresnet1d152,
                              xresnet1d18_deep, xresnet1d34_deep, xresnet1d50_deep,
                              xresnet1d18_deeper, xresnet1d34_deeper, xresnet1d50_deeper)
from utilities.timeseries_utils import TimeseriesDatasetCrops, ToTensor, aggregate_predictions
from utilities.utils import evaluate_experiment


def add_metrics(last_metrics, new_metric):
    """
    Adds a new metric to the list of last metrics.

    Args:
        last_metrics (list): List of previous metrics.
        new_metric (float or list): New metric(s) to add.

    Returns:
        list: Updated list of metrics.
    """
    if isinstance(new_metric, list):
        return last_metrics + new_metric
    else:
        return last_metrics + [new_metric]


class MetricFunc(Callback):
    """Obtains a score using a user-supplied function func (potentially ignoring targets with ignore_idx)."""

    def __init__(self, func, name="MetricFunc", ignore_idx=None, one_hot_encode_target=True, argmax_pred=False,
                 softmax_pred=True, flatten_target=True, sigmoid_pred=False, metric_component=None):
        super().__init__()
        self.func = func
        self.ignore_idx = ignore_idx
        self.one_hot_encode_target = one_hot_encode_target
        self.argmax_pred = argmax_pred
        self.softmax_pred = softmax_pred
        self.flatten_target = flatten_target
        self.sigmoid_pred = sigmoid_pred
        self.metric_component = metric_component
        self.name = name
        # accumulators for predictions/targets; the full metric is computed in on_epoch_end
        self.y_true = None
        self.y_pred = None
        self.metric_complete = None

    def on_epoch_begin(self, **kwargs):
        # reset the accumulators at the start of every epoch
        self.y_pred = None
        self.y_true = None

    def on_batch_end(self, last_output, last_target, **kwargs):
        # flatten everything (to make it also work for annotation tasks)
        y_pred_flat = last_output.view((-1, last_output.size()[-1]))
        y_true_flat = last_target.view(-1) if self.flatten_target else last_target

        # optionally take argmax of predictions, or map logits through softmax/sigmoid
        if self.argmax_pred is True:
            y_pred_flat = y_pred_flat.argmax(dim=1)
        elif self.softmax_pred is True:
            y_pred_flat = F.softmax(y_pred_flat, dim=1)
        elif self.sigmoid_pred is True:
            y_pred_flat = torch.sigmoid(y_pred_flat)

        # potentially remove ignore_idx entries
        if self.ignore_idx is not None:
            selected_indices = (y_true_flat != self.ignore_idx).nonzero().squeeze()
            y_pred_flat = y_pred_flat[selected_indices]
            y_true_flat = y_true_flat[selected_indices]

        y_pred_flat = to_np(y_pred_flat)
        y_true_flat = to_np(y_true_flat)

        if self.one_hot_encode_target is True:
            # one-hot encode integer targets (numpy has no built-in one-hot helper)
            y_true_flat = np.eye(last_output.size()[-1], dtype=np.float32)[y_true_flat.astype(np.intp)]

        if self.y_pred is None:
            self.y_pred = y_pred_flat
            self.y_true = y_true_flat
        else:
            self.y_pred = np.concatenate([self.y_pred, y_pred_flat], axis=0)
            self.y_true = np.concatenate([self.y_true, y_true_flat], axis=0)

    def on_epoch_end(self, last_metrics, **kwargs):
        # compute the full metric (possibly with multiple components) on the accumulated arrays;
        # it remains accessible afterwards via self.metric_complete
        self.metric_complete = self.func(self.y_true, self.y_pred)
        if self.metric_component is not None:
            return add_metrics(last_metrics, self.metric_complete[self.metric_component])
        else:
            return add_metrics(last_metrics, self.metric_complete)
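
# The original import of GradientClipping (commented out in the source) no longer
# resolves under the fastai import paths used above, yet the symbol is used in
# _get_learner below. The stand-in here is an assumption, not the upstream fastai
# implementation: a minimal callback in the same v1-style protocol as MetricFunc
# that clips the global gradient norm after the backward pass.
class GradientClipping(Callback):
    def __init__(self, learn, clip: float = 0.):
        super().__init__()
        self.learn = learn
        self.clip = clip

    def on_backward_end(self, **kwargs):
        # clip the global gradient norm of all model parameters before the optimizer step
        if self.clip:
            nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)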

def fmax_metric(targs, preds):
    return evaluate_experiment(targs, preds)["Fmax"]


def auc_metric(targs, preds):
    return evaluate_experiment(targs, preds)["macro_auc"]


def mse_flat(preds, targs):
    return torch.mean(torch.pow(preds.view(-1) - targs.view(-1), 2))


def nll_regression(preds, targs):
    # preds: bs, 2 (mean and log-variance); targs: bs, 1
    preds_mean = preds[:, 0]
    # warning: the variance output goes through an exponential map to ensure positivity
    preds_var = torch.clamp(torch.exp(preds[:, 1]), 1e-4, 1e10)
    return torch.mean(torch.log(2 * math.pi * preds_var) / 2) + torch.mean(
        torch.pow(preds_mean - targs[:, 0], 2) / 2 / preds_var)


def nll_regression_init(m):
    assert isinstance(m, nn.Linear)
    nn.init.normal_(m.weight, 0., 0.001)
    nn.init.constant_(m.bias, 4)
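
# Sanity check (sketch): nll_regression is the negative log-likelihood of a Gaussian
# whose mean is the first output channel and whose log-variance is the second, i.e.
# 0.5*log(2*pi*var) + (mu - y)^2 / (2*var). For mu = y = 0 and log_var = 0 (var = 1):
#
#   nll_regression(torch.tensor([[0.0, 0.0]]), torch.tensor([[0.0]]))
#   # -> 0.5 * log(2*pi) ~ 0.9189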

def lr_find_plot(learner, path, filename="lr_find", n_skip=10, n_skip_end=2):
    """Runs the learning-rate finder and saves the resulting plot (normally only
    shown inline in Jupyter) as a file; the x-axis shows the learning rate."""
    learner.lr_find()

    backend_old = matplotlib.get_backend()
    plt.switch_backend('agg')
    plt.ylabel("loss")
    plt.xlabel("learning rate (log scale)")

    losses = [to_np(x) for x in learner.recorder.losses[n_skip:-(n_skip_end + 1)]]
    plt.plot(learner.recorder.lrs[n_skip:-(n_skip_end + 1)], losses)
    plt.xscale('log')

    plt.savefig(str(path / (filename + '.png')))
    plt.switch_backend(backend_old)


def losses_plot(learner, path, filename="losses", last: int = None):
    """Saves the training/validation loss curves (normally only shown inline in
    Jupyter) as a file; the x-axis shows the number of processed batches."""
    backend_old = matplotlib.get_backend()
    plt.switch_backend('agg')
    plt.ylabel("loss")
    plt.xlabel("Batches processed")

    last = ifnone(last, len(learner.recorder.nb_batches))
    l_b = np.sum(learner.recorder.nb_batches[-last:])
    iterations = range_of(learner.recorder.losses)[-l_b:]
    plt.plot(iterations, learner.recorder.losses[-l_b:], label='Train')

    val_iter = learner.recorder.nb_batches[-last:]
    val_iter = np.cumsum(val_iter) + np.sum(learner.recorder.nb_batches[:-last])
    plt.plot(val_iter, learner.recorder.val_losses[-last:], label='Validation')

    plt.legend()
    plt.savefig(str(path / (filename + '.png')))
    plt.switch_backend(backend_old)


class FastaiModel(ClassificationModel):
    def __init__(self, name, n_classes, freq, output_folder, input_shape, pretrained=False, input_size=2.5,
                 input_channels=12, chunkify_train=False, chunkify_valid=True, bs=128, ps_head=0.5,
                 lin_ftrs_head=None, wd=1e-2, epochs=50, lr=1e-2, kernel_size=5, loss="binary_cross_entropy",
                 pretrained_folder=None, n_classes_pretrained=None, gradual_unfreezing=True,
                 discriminative_lrs=True, epochs_finetuning=30, early_stopping=None, aggregate_fn="max",
                 concat_train_val=False):
        super().__init__()

        if lin_ftrs_head is None:
            lin_ftrs_head = [128]

        self.name = name
        self.num_classes = n_classes if loss != "nll_regression" else 2
        self.target_fs = freq
        self.output_folder = Path(output_folder)

        self.input_size = int(input_size * self.target_fs)
        self.input_channels = input_channels

        self.chunkify_train = chunkify_train
        self.chunkify_valid = chunkify_valid

        self.chunk_length_train = 2 * self.input_size
        self.chunk_length_valid = self.input_size

        self.min_chunk_length = self.input_size

        self.stride_length_train = self.input_size
        self.stride_length_valid = self.input_size // 2

        self.copies_valid = 0  # >0 should only be used with chunkify_valid=False

        self.bs = bs
        self.ps_head = ps_head
        self.lin_ftrs_head = lin_ftrs_head
        self.wd = wd
        self.epochs = epochs
        self.lr = lr
        self.kernel_size = kernel_size
        self.loss = loss
        self.input_shape = input_shape

        if pretrained:
            if pretrained_folder is None:
                pretrained_folder = Path('../output/exp0/models/' + name.split("_pretrained")[0] + '/')
            if n_classes_pretrained is None:
                n_classes_pretrained = 71

        self.pretrained_folder = None if pretrained_folder is None else Path(pretrained_folder)
        self.n_classes_pretrained = n_classes_pretrained
        self.discriminative_lrs = discriminative_lrs
        self.gradual_unfreezing = gradual_unfreezing
        self.epochs_finetuning = epochs_finetuning

        self.early_stopping = early_stopping
        self.aggregate_fn = aggregate_fn
        self.concat_train_val = concat_train_val

    def fit(self, X_train, y_train, X_val, y_val):
        # convert everything to float32
        X_train = [l.astype(np.float32) for l in X_train]
        X_val = [l.astype(np.float32) for l in X_val]
        y_train = [l.astype(np.float32) for l in y_train]
        y_val = [l.astype(np.float32) for l in y_val]

        if self.concat_train_val:
            X_train += X_val
            y_train += y_val

        if self.pretrained_folder is None:  # train from scratch
            print("Training from scratch...")
            learn = self._get_learner(X_train, y_train, X_val, y_val)

            learn.model.apply(weight_init)

            # initialization for the regression output
            if self.loss == "nll_regression" or self.loss == "mse":
                output_layer_new = learn.model.get_output_layer()
                output_layer_new.apply(nll_regression_init)
                learn.model.set_output_layer(output_layer_new)

            lr_find_plot(learn, self.output_folder)
            learn.fit_one_cycle(self.epochs, self.lr)
            losses_plot(learn, self.output_folder)
        else:  # finetuning
            print("Finetuning...")
            # create learner
            learn = self._get_learner(X_train, y_train, X_val, y_val, self.n_classes_pretrained)

            # load the pretrained model
            learn.path = self.pretrained_folder
            learn.load(self.pretrained_folder.stem)
            learn.path = self.output_folder

            # exchange the top layer
            output_layer = learn.model.get_output_layer()
            output_layer_new = nn.Linear(output_layer.in_features, self.num_classes).cuda()
            apply_init(output_layer_new, nn.init.kaiming_normal_)
            learn.model.set_output_layer(output_layer_new)

            # layer groups
            if self.discriminative_lrs:
                layer_groups = learn.model.get_layer_groups()
                learn.split(layer_groups)

            learn.train_bn = True  # make sure batchnorm layers stay in train mode

            # train
            lr = self.lr
            if self.gradual_unfreezing:
                assert self.discriminative_lrs is True
                learn.freeze()
                lr_find_plot(learn, self.output_folder, "lr_find0")
                learn.fit_one_cycle(self.epochs_finetuning, lr)
                losses_plot(learn, self.output_folder, "losses0")
                # alternative schedule (disabled): unfreeze one layer group at a time
                # for n in [0]:  # range(len(layer_groups)):
                #     learn.freeze_to(-n - 1)
                #     lr_find_plot(learn, self.output_folder, "lr_find" + str(n))
                #     learn.fit_one_cycle(self.epochs_gradual_unfreezing, slice(lr))
                #     losses_plot(learn, self.output_folder, "losses" + str(n))
                #     if n == 0:  # reduce lr after the first step
                #         lr /= 10.
                #     if n > 0 and (self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru")):
                #         lr /= 10  # reduce lr further for RNNs

            learn.unfreeze()
            lr_find_plot(learn, self.output_folder, "lr_find" + str(len(layer_groups)))
            learn.fit_one_cycle(self.epochs_finetuning, slice(lr / 1000, lr / 10))
            losses_plot(learn, self.output_folder, "losses" + str(len(layer_groups)))

        learn.save(self.name)  # even with early stopping, the best model will have been loaded again
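
    # Usage (sketch): X_* are lists of float arrays of shape [timesteps, channels] and
    # y_* are multi-hot label vectors; the concrete name, shapes and class count below
    # are assumptions for illustration only.
    #
    #   model = FastaiModel("fastai_xresnet1d101", n_classes=71, freq=100,
    #                       output_folder="./out", input_shape=(1000, 12))
    #   model.fit(X_train, y_train, X_val, y_val)
    #   preds = model.predict(X_test)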

    def predict(self, X):
        X = [l.astype(np.float32) for l in X]
        y_dummy = [np.ones(self.num_classes, dtype=np.float32) for _ in range(len(X))]

        learn = self._get_learner(X, y_dummy, X, y_dummy)
        learn.load(self.name)

        preds, targs = learn.get_preds()
        preds = to_np(preds)

        idmap = learn.data.valid_ds.get_id_mapping()

        return aggregate_predictions(preds, idmap=idmap,
                                     aggregate_fn=np.mean if self.aggregate_fn == "mean" else np.amax)

    def _get_learner(self, X_train, y_train, X_val, y_val, num_classes=None):
        df_train = pd.DataFrame({"data": range(len(X_train)), "label": y_train})
        df_valid = pd.DataFrame({"data": range(len(X_val)), "label": y_val})

        tfms_ptb_xl = [ToTensor()]

        ds_train = TimeseriesDatasetCrops(df_train, self.input_size, num_classes=self.num_classes,
                                          chunk_length=self.chunk_length_train if self.chunkify_train else 0,
                                          min_chunk_length=self.min_chunk_length, stride=self.stride_length_train,
                                          transforms=tfms_ptb_xl, annotation=False, col_lbl="label",
                                          npy_data=X_train)
        ds_valid = TimeseriesDatasetCrops(df_valid, self.input_size, num_classes=self.num_classes,
                                          chunk_length=self.chunk_length_valid if self.chunkify_valid else 0,
                                          min_chunk_length=self.min_chunk_length, stride=self.stride_length_valid,
                                          transforms=tfms_ptb_xl, annotation=False, col_lbl="label",
                                          npy_data=X_val)

        db = DataLoaders(ds_train, ds_valid)

        if self.loss == "binary_cross_entropy":
            loss = F.binary_cross_entropy_with_logits
        elif self.loss == "cross_entropy":
            loss = F.cross_entropy
        elif self.loss == "mse":
            loss = mse_flat
        elif self.loss == "nll_regression":
            loss = nll_regression
        else:
            raise ValueError("loss not found: " + self.loss)
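
        # Note: F.binary_cross_entropy_with_logits operates on raw (pre-sigmoid) outputs
        # and multi-hot float targets; the early-stopping metrics at the end of this
        # method are therefore configured with sigmoid_pred=True.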

        self.input_channels = self.input_shape[-1]
        metrics = []

        print("model:", self.name)
        # note: all models of a particular kind share the same prefix but potentially carry a
        # different postfix, such as _input256
        num_classes = self.num_classes if num_classes is None else num_classes

        # resnet: resnet1d18, resnet1d34, resnet1d50, resnet1d101, resnet1d152, resnet1d_wang, wrn1d_22
        if self.name.startswith("fastai_resnet1d18"):
            model = resnet1d18(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                               kernel_size=self.kernel_size, ps_head=self.ps_head,
                               lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d34"):
            model = resnet1d34(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                               kernel_size=self.kernel_size, ps_head=self.ps_head,
                               lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d50"):
            model = resnet1d50(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                               kernel_size=self.kernel_size, ps_head=self.ps_head,
                               lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d101"):
            model = resnet1d101(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d152"):
            model = resnet1d152(num_classes=num_classes, input_channels=self.input_channels, inplanes=128,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_resnet1d_wang"):
            model = resnet1d_wang(num_classes=num_classes, input_channels=self.input_channels,
                                  kernel_size=self.kernel_size, ps_head=self.ps_head,
                                  lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_wrn1d_22"):
            model = wrn1d_22(num_classes=num_classes, input_channels=self.input_channels,
                             kernel_size=self.kernel_size, ps_head=self.ps_head,
                             lin_ftrs_head=self.lin_ftrs_head)
        # xresnet (order important for string capture)
        elif self.name.startswith("fastai_xresnet1d18_deeper"):
            model = xresnet1d18_deeper(num_classes=num_classes, input_channels=self.input_channels,
                                       kernel_size=self.kernel_size, ps_head=self.ps_head,
                                       lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d34_deeper"):
            model = xresnet1d34_deeper(num_classes=num_classes, input_channels=self.input_channels,
                                       kernel_size=self.kernel_size, ps_head=self.ps_head,
                                       lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d50_deeper"):
            model = xresnet1d50_deeper(num_classes=num_classes, input_channels=self.input_channels,
                                       kernel_size=self.kernel_size, ps_head=self.ps_head,
                                       lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d18_deep"):
            model = xresnet1d18_deep(num_classes=num_classes, input_channels=self.input_channels,
                                     kernel_size=self.kernel_size, ps_head=self.ps_head,
                                     lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d34_deep"):
            model = xresnet1d34_deep(num_classes=num_classes, input_channels=self.input_channels,
                                     kernel_size=self.kernel_size, ps_head=self.ps_head,
                                     lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d50_deep"):
            model = xresnet1d50_deep(num_classes=num_classes, input_channels=self.input_channels,
                                     kernel_size=self.kernel_size, ps_head=self.ps_head,
                                     lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d18"):
            model = xresnet1d18(num_classes=num_classes, input_channels=self.input_channels,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d34"):
            model = xresnet1d34(num_classes=num_classes, input_channels=self.input_channels,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d50"):
            model = xresnet1d50(num_classes=num_classes, input_channels=self.input_channels,
                                kernel_size=self.kernel_size, ps_head=self.ps_head,
                                lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d101"):
            model = xresnet1d101(num_classes=num_classes, input_channels=self.input_channels,
                                 kernel_size=self.kernel_size, ps_head=self.ps_head,
                                 lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_xresnet1d152"):
            model = xresnet1d152(num_classes=num_classes, input_channels=self.input_channels,
                                 kernel_size=self.kernel_size, ps_head=self.ps_head,
                                 lin_ftrs_head=self.lin_ftrs_head)
        # inception: passing the default kernel size of 5 leads to a max kernel size of 40-1
        # in the inception model as proposed in the original paper
        elif self.name == "fastai_inception1d_no_residual":  # note: order important for string capture
            model = inception1d(num_classes=num_classes, input_channels=self.input_channels, use_residual=False,
                                ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
                                kernel_size=8 * self.kernel_size)
        elif self.name.startswith("fastai_inception1d"):
            model = inception1d(num_classes=num_classes, input_channels=self.input_channels, use_residual=True,
                                ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head,
                                kernel_size=8 * self.kernel_size)
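        # e.g. with the default self.kernel_size=5, the inception branches above receive
        # kernel_size = 8 * 5 = 40 (which, per the comment above, corresponds to a
        # maximum internal kernel size of 40-1 = 39)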
        # BasicConv1d: fcn, fcn_wang, schirrmeister, sen, basic1d
        elif self.name.startswith("fastai_fcn_wang"):  # note: order important for string capture
            model = fcn_wang(num_classes=num_classes, input_channels=self.input_channels,
                             ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_fcn"):
            model = fcn(num_classes=num_classes, input_channels=self.input_channels)
        elif self.name.startswith("fastai_schirrmeister"):
            model = schirrmeister(num_classes=num_classes, input_channels=self.input_channels,
                                  ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_sen"):
            model = sen(num_classes=num_classes, input_channels=self.input_channels,
                        ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_basic1d"):
            model = basic1d(num_classes=num_classes, input_channels=self.input_channels,
                            kernel_size=self.kernel_size, ps_head=self.ps_head,
                            lin_ftrs_head=self.lin_ftrs_head)
        # RNN
        elif self.name.startswith("fastai_lstm_bidir"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
                          bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_gru_bidir"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
                          bidirectional=True, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_lstm"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=True,
                          bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        elif self.name.startswith("fastai_gru"):
            model = RNN1d(input_channels=self.input_channels, num_classes=num_classes, lstm=False,
                          bidirectional=False, ps_head=self.ps_head, lin_ftrs_head=self.lin_ftrs_head)
        else:
            raise ValueError("Model not found: " + self.name)

        learn = Learner(db, model, loss_func=loss, metrics=metrics, wd=self.wd, path=self.output_folder)

        # RNNs are trained with gradient clipping to keep the backward pass stable
        if self.name.startswith("fastai_lstm") or self.name.startswith("fastai_gru"):
            learn.callback_fns.append(partial(GradientClipping, clip=0.25))

        if self.early_stopping is not None:
            # supported options: valid_loss, macro_auc, fmax
            if self.early_stopping == "macro_auc" and self.loss != "mse" and self.loss != "nll_regression":
                metric = MetricFunc(auc_metric, self.early_stopping, one_hot_encode_target=False, argmax_pred=False,
                                    softmax_pred=False, sigmoid_pred=True, flatten_target=False)
                learn.metrics.append(metric)
                learn.callback_fns.append(
                    partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
            elif self.early_stopping == "fmax" and self.loss != "mse" and self.loss != "nll_regression":
                metric = MetricFunc(fmax_metric, self.early_stopping, one_hot_encode_target=False, argmax_pred=False,
                                    softmax_pred=False, sigmoid_pred=True, flatten_target=False)
                learn.metrics.append(metric)
                learn.callback_fns.append(
                    partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))
            elif self.early_stopping == "valid_loss":
                learn.callback_fns.append(
                    partial(SaveModelCallback, monitor=self.early_stopping, every='improvement', name=self.name))

        return learn
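
# Minimal smoke test (sketch): trains one of the models on random data. The shapes,
# sampling frequency, class count and output folder below are assumptions chosen for
# illustration, not values from the accompanying experiments.
if __name__ == "__main__":
    out = Path("./output_demo")
    out.mkdir(parents=True, exist_ok=True)

    rng = np.random.default_rng(0)
    # eight 10 s "recordings" with 12 channels at 100 Hz, five binary labels each
    X = [rng.normal(size=(1000, 12)).astype(np.float32) for _ in range(8)]
    y = [(rng.random(5) > 0.5).astype(np.float32) for _ in range(8)]

    model = FastaiModel("fastai_xresnet1d50", n_classes=5, freq=100, output_folder=out,
                        input_shape=X[0].shape, epochs=1, bs=2)
    model.fit(X[:6], y[:6], X[6:], y[6:])
    print("predictions shape:", np.shape(model.predict(X[6:])))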