| import os |
|
|
| import numpy as np |
| import skorch |
| import torch |
| from sklearn.metrics import confusion_matrix, make_scorer |
| from skorch.callbacks import BatchScoring |
| from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter |
| from skorch.callbacks.training import Checkpoint |
|
|
| from .LRCallback import LearningRateDecayCallback |
|
|
# Module-level handle for the summary writer on which
# MetricsVizualization.on_batch_end calls ``add_scalar``; it is initialised to
# None here.
writer = None
|
|
def accuracy_score(y_true, y_pred: torch.Tensor, task: str = None, mirna_flag: bool = False):
    """Score one batch of predictions for the ``premirna`` or ``sncrna`` task.

    The last column of ``y_pred`` is dropped before the arg-max is taken, so
    it never competes as a predicted class.

    Args:
        y_true: Ground-truth labels. Must support ``.squeeze()`` and
            element-wise comparison (assumed to be a NumPy array of shape
            ``(n,)`` or ``(n, 1)`` -- TODO confirm against the caller).
        y_pred: 2-D tensor whose columns are per-class scores followed by one
            trailing column that is ignored here.
        task: Either ``"premirna"`` or ``"sncrna"``.
        mirna_flag: Only used for ``"premirna"``: the label value whose
            samples are scored.

    Returns:
        ``"premirna"``: the *number* (not the ratio) of samples labelled
        ``mirna_flag`` whose predicted class equals ``mirna_flag``.
        ``"sncrna"``: the fraction of rows whose predicted class equals the
        label.

    Raises:
        ValueError: If ``task`` is not one of the two supported tasks.
    """
    if task not in ("premirna", "sncrna"):
        # Previously an unsupported task fell through without ever computing
        # a result; fail loudly with an actionable message instead.
        raise ValueError(
            f"accuracy_score supports task 'premirna' or 'sncrna', got {task!r}"
        )

    # Arg-max over the class columns only (the trailing column is excluded).
    predicted = torch.max(y_pred[:, :-1], 1).indices.cpu().numpy()
    targets = y_true.squeeze()

    if task == "premirna":
        # Restrict scoring to the samples whose label equals ``mirna_flag``.
        selected = np.where(targets == mirna_flag)
        return np.sum(predicted[selected] == mirna_flag)

    return np.sum(predicted == targets) / y_pred.shape[0]
|
|
|
|
def accuracy_score_tcga(y_true, y_pred):
    """Return the sample-weighted mean of the per-class recalls.

    ``y_pred`` is expected to be 2-D: its last column is used as the
    per-sample weight and the remaining columns as class scores. Tensors are
    converted to NumPy arrays first; other inputs are used as given.

    Args:
        y_true: True class labels (tensor or array-like).
        y_pred: Class scores with the sample weights appended as last column.

    Returns:
        Mean over classes of ``diag(C) / row_sum(C)`` where ``C`` is the
        weighted confusion matrix; classes whose row sum is zero (NaN recall)
        are left out of the mean.
    """
    # Detach tensors from the graph and move them to host memory.
    y_pred, y_true = (
        values.clone().detach().cpu().numpy() if torch.is_tensor(values) else values
        for values in (y_pred, y_true)
    )

    weights = y_pred[:, -1]
    predicted_labels = np.argmax(y_pred[:, :-1], axis=1)

    conf_mat = confusion_matrix(y_true, predicted_labels, sample_weight=weights)
    # 0/0 for classes without true samples is expected; silence the warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        recall_per_class = np.diag(conf_mat) / conf_mat.sum(axis=1)

    # Drop undefined recalls so they do not turn the mean into NaN.
    defined = ~np.isnan(recall_per_class)
    return np.mean(recall_per_class[defined])
|
|
def score_callbacks(cfg):
    """Build the accuracy-scoring callbacks for the configured task.

    Args:
        cfg: Mapping providing ``cfg["task"]`` (``"premirna"``, ``"sncrna"``
            or ``"tcga"``) and ``cfg["train_split"]``.

    Returns:
        A list of scoring callbacks:

        * ``premirna``: ``train_acc`` and ``train_acc_mirna``; when
          ``cfg["train_split"]`` is truthy also ``val_acc_mirna`` and
          ``val_acc`` (in that order).
        * ``sncrna`` / ``tcga``: ``train_acc``; ``val_acc`` is added when
          ``cfg["train_split"]`` is truthy or the task is ``tcga``.

    Raises:
        ValueError: If ``cfg["task"]`` is not one of the supported tasks.
    """
    task = cfg["task"]
    if task not in ("premirna", "sncrna", "tcga"):
        # Without this check an unknown task only surfaced as an
        # UnboundLocalError at the final return.
        raise ValueError(
            f"score_callbacks supports task 'premirna', 'sncrna' or 'tcga', got {task!r}"
        )

    if task == "tcga":
        acc_scorer = make_scorer(accuracy_score_tcga)
    else:
        acc_scorer = make_scorer(accuracy_score, task=task)

    if task == "premirna":
        # A second scorer that targets the samples labelled True.
        acc_scorer_mirna = make_scorer(accuracy_score, task=task, mirna_flag=True)

        scoring_callbacks = [
            BatchScoringPremirna(
                mirna_flag=False,
                scoring=acc_scorer,
                on_train=True,
                lower_is_better=False,
                name="train_acc",
            ),
            BatchScoringPremirna(
                mirna_flag=True,
                scoring=acc_scorer_mirna,
                on_train=True,
                lower_is_better=False,
                name="train_acc_mirna",
            ),
        ]
        if cfg["train_split"]:
            scoring_callbacks.extend([
                BatchScoringPremirna(
                    mirna_flag=True,
                    scoring=acc_scorer_mirna,
                    lower_is_better=False,
                    name="val_acc_mirna",
                ),
                BatchScoringPremirna(
                    mirna_flag=False,
                    scoring=acc_scorer,
                    lower_is_better=False,
                    name="val_acc",
                ),
            ])
        return scoring_callbacks

    # sncrna / tcga
    scoring_callbacks = [
        BatchScoring(acc_scorer, on_train=True, lower_is_better=False, name="train_acc")
    ]
    # The validation scorer is always attached for tcga, regardless of
    # cfg["train_split"].
    if cfg["train_split"] or task == "tcga":
        scoring_callbacks.append(
            BatchScoring(acc_scorer, lower_is_better=False, name="val_acc")
        )
    return scoring_callbacks
|
|
def get_callbacks(path, cfg):
    """Assemble the callback list handed to the network.

    The resulting order is: the learning-rate callback, the scoring callbacks
    from ``score_callbacks(cfg)``, then (optionally) ``MetricsVizualization``
    and (optionally) a ``Checkpoint``.

    Args:
        path: Output directory; checkpoints go to ``path + "/ckpt/"``.
        cfg: Mapping providing ``tensorboard``, ``train_split``, ``task``,
            ``inference`` and, when a checkpoint is configured,
            ``trained_on``.

    Returns:
        List of callbacks (classes, instances or ``(name, callback)`` tuples).

    Side effects:
        Rebinds the module-level ``writer`` when ``cfg['tensorboard']`` is
        true, and creates the checkpoint directory when a ``Checkpoint`` is
        added.
    """
    # MetricsVizualization reads the module-level ``writer``. Importing it
    # under the same name inside this function would only create a local and
    # leave the global at None, so rebind the global explicitly.
    global writer

    callback_list = [("lrcallback", LearningRateDecayCallback)]
    # Explicit comparisons are kept so non-boolean config values behave as
    # they did before.
    if cfg['tensorboard'] == True:
        from .tbWriter import writer as tb_writer
        writer = tb_writer
        callback_list.append(MetricsVizualization)

    if (cfg["train_split"] or cfg['task'] == 'tcga') and cfg['inference'] == False:
        # Monitor the training score when the model is trained on all data.
        monitor = 'train_acc_best' if cfg['trained_on'] == 'full' else "val_acc_best"
        ckpt_path = path + "/ckpt/"
        # An already existing directory is fine; any other failure (e.g.
        # missing permissions) is reported instead of being swallowed.
        os.makedirs(ckpt_path, exist_ok=True)
        model_name = f'model_params_{cfg["task"]}.pt'
        callback_list.append(
            Checkpoint(monitor=monitor, dirname=ckpt_path, f_params=model_name)
        )

    # Insert the scorers right after the LR callback so they precede the
    # visualisation and checkpoint callbacks appended above.
    callback_list[1:1] = score_callbacks(cfg)

    return callback_list
|
|
|
|
class MetricsVizualization(skorch.callbacks.Callback):
    """Callback that forwards per-batch metrics to the module-level ``writer``.

    After every batch the most recent batch entry of ``net.history`` is read
    and written with ``writer.add_scalar``, using ``batch_idx`` as the step.
    """

    def __init__(self, batch_idx=0) -> None:
        """Store the initial step counter ``batch_idx``."""
        super().__init__()
        self.batch_idx = batch_idx

    def on_batch_end(self, net, training, **kwargs):
        """Log accuracy and loss of the batch that just finished.

        Training batches additionally log ``net.lr`` and advance the step
        counter; validation batches reuse the current step.
        """
        if training:
            writer.add_scalar("Metrics/lr", net.lr, self.batch_idx)
            scalars = (
                ("Accuracy/train_acc", "train_acc"),
                ("Loss/train_loss", "train_loss"),
            )
        else:
            scalars = (
                ("Accuracy/val_acc", "val_acc"),
                ("Loss/val_loss", "valid_loss"),
            )

        # History index: last epoch -> its batches -> last batch -> column.
        for tag, column in scalars:
            writer.add_scalar(
                tag,
                net.history[-1, "batches", -1, column],
                self.batch_idx,
            )

        if training:
            self.batch_idx += 1
|
|
class BatchScoringPremirna(ScoringBase):
    """Batch scoring callback normalised by a per-label sample count.

    Each batch score is recorded in the history. At epoch end the batch
    scores are summed and divided by ``total_num_samples``, the number of
    samples whose label equals ``mirna_flag``; that count is accumulated only
    until the first ``on_epoch_end`` call.
    """

    def __init__(self,mirna_flag:bool = False,*args,**kwargs):
        """Forward ``*args``/``**kwargs`` to ``ScoringBase`` and set up counters."""
        super().__init__(*args,**kwargs)
        self.mirna_flag = mirna_flag
        self.total_num_samples = 0
        # True until the first epoch has ended; gates the sample counting.
        self.first_batch_flag = True

    def on_batch_end(self, net, X, y, training, **kwargs):
        """Score the finished batch and record it under ``self.name_``.

        Batches of the other phase (train vs. validation) are ignored.
        """
        if training != self.on_train:
            return

        cached_preds = [kwargs['y_pred']]

        if self.first_batch_flag:
            # assumes kwargs["batch"][1] is a label tensor of shape (n, 1)
            # -- TODO confirm against the data loader
            label_matches = kwargs["batch"][1] == self.mirna_flag
            matches_in_batch = sum(label_matches).detach().cpu().numpy()[0]
            self.total_num_samples += matches_in_batch

        with _cache_net_forward_iter(net, self.use_caching, cached_preds) as cached_net:
            if y is not None:
                y = self.target_extractor(y)
            try:
                batch_score = self._scoring(cached_net, X, y)
                cached_net.history.record_batch(self.name_, batch_score)
            except KeyError:
                pass

    def get_avg_score(self, history):
        """Return the summed batch scores of the last epoch over ``total_num_samples``."""
        bs_key = 'train_batch_size' if self.on_train else 'valid_batch_size'

        # The batch sizes are fetched together with the scores but not used.
        _, batch_scores = list(zip(
            *history[-1, 'batches', :, [bs_key, self.name_]]))

        return sum(batch_scores)/self.total_num_samples

    def on_epoch_end(self, net, **kwargs):
        """Record the epoch score and, when available, its ``_best`` flag."""
        # From now on the per-label sample count is frozen.
        self.first_batch_flag = False
        history = net.history

        # Nothing to aggregate if no batch score was recorded this epoch.
        try:
            history[-1, 'batches', :, self.name_]
        except KeyError:
            return

        epoch_score = self.get_avg_score(history)
        improved = self._is_best_score(epoch_score)
        if improved:
            self.best_score_ = epoch_score

        history.record(self.name_, epoch_score)
        if improved is None:
            return
        history.record(self.name_ + '_best', bool(improved))
|
|