import logging
import os
import pickle

import skorch
import torch
from skorch.dataset import Dataset, ValidSplit
from skorch.setter import optimizer_setter
from skorch.utils import is_dataset, to_device

logger = logging.getLogger(__name__)

class Net(skorch.NeuralNet):
    def __init__(
        self,
        clip=0.25,
        top_k=1,
        correct=0,
        save_embedding=False,
        gene_embedds=None,
        second_input_embedd=None,
        confidence_threshold=0.95,
        *args,
        **kwargs
    ):
        self.clip = clip
        self.curr_epoch = 0
        super().__init__(*args, **kwargs)
        self.correct = correct
        self.save_embedding = save_embedding
        # Avoid mutable default arguments; fall back to fresh lists.
        self.gene_embedds = gene_embedds if gene_embedds is not None else []
        self.second_input_embedds = (
            second_input_embedd if second_input_embedd is not None else []
        )
        # "module__main_config" and "max_epochs" are required keyword arguments.
        self.main_config = kwargs["module__main_config"]
        self.train_config = self.main_config["train_config"]
        # top_k is read from the train config rather than the top_k argument.
        self.top_k = self.train_config.top_k
        self.num_classes = self.main_config["model_config"].num_classes
        self.labels_mapping_path = self.train_config.labels_mapping_path
        if self.labels_mapping_path:
            with open(self.labels_mapping_path, 'rb') as handle:
                self.labels_mapping_dict = pickle.load(handle)
        self.confidence_threshold = confidence_threshold
        self.max_epochs = kwargs["max_epochs"]
        self.task = ''
        self.log_tb = False
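    # A minimal construction sketch (MyModule/MyCriterion are hypothetical;
    # the config layout is inferred from how this class uses it):
    #
    #     net = Net(
    #         module=MyModule,
    #         module__main_config=main_config,   # dict with "train_config",
    #                                            # "model_config", "task", ...
    #         criterion=MyCriterion,             # instantiated as criterion(main_config)
    #         max_epochs=100,                    # required: read from kwargs
    #         train_split=None,
    #     )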
    def set_save_epoch(self):
        '''
        Scale the best training epoch by the validation size: when there is
        no train/valid split, training runs on the full data, so the saved
        epoch is stretched accordingly.
        '''
        if self.task != 'tcga':
            if self.train_split:
                self.save_epoch = self.main_config["train_config"].train_epoch
            else:
                self.save_epoch = round(
                    self.main_config["train_config"].train_epoch
                    * (1 + self.main_config["valid_size"])
                )

    def save_benchmark_model(self):
        '''
        Save model parameters at the benchmark epoch; used when
        train_split is None.
        '''
        os.makedirs("ckpt", exist_ok=True)
        ckpt_dir = os.path.join(os.getcwd(), "ckpt")
        self.save_params(
            f_params=os.path.join(ckpt_dir, f'model_params_{self.main_config["task"]}.pt')
        )

    def fit(self, X, y=None, valid_ds=None, **fit_params):
        self.all_lengths = [[] for _ in range(self.num_classes)]
        self.median_lengths = []

        if not self.warm_start or not self.initialized_:
            self.initialize()

        # An explicitly supplied validation set overrides the internal split.
        self.validation_dataset = valid_ds if valid_ds else None

        self.partial_fit(X, y, **fit_params)
        return self
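    # fit_loop unwraps an external validation set via
    # validation_dataset.keywords["valid_ds"], so valid_ds is expected to be a
    # functools.partial-like object carrying the dataset as a keyword.
    # A hedged sketch (my_valid_dataset is hypothetical):
    #
    #     from functools import partial
    #     wrapped = partial(lambda valid_ds: valid_ds, valid_ds=my_valid_dataset)
    #     net.fit(X_train, y_train, valid_ds=wrapped)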
    def fit_loop(self, X, y=None, epochs=None, **fit_params):
        # Round training accuracy more coarsely when trained on the full data.
        rounding_digits = 3
        if self.main_config['trained_on'] == 'full':
            rounding_digits = 2
        self.check_data(X, y)
        epochs = epochs if epochs is not None else self.max_epochs

        dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params)

        # An external validation set takes precedence over the internal split.
        if self.validation_dataset is not None:
            dataset_valid = self.validation_dataset.keywords["valid_ds"]

        on_epoch_kwargs = {
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
        }

        iterator_train = self.get_iterator(dataset_train, training=True)
        iterator_valid = None
        if dataset_valid is not None:
            iterator_valid = self.get_iterator(dataset_valid, training=False)

        self.set_save_epoch()

        for epoch_no in range(epochs):
            self.curr_epoch = epoch_no

            # Save the benchmark checkpoint once the scaled save epoch is reached.
            if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split is None:
                self.save_benchmark_model()

            self.notify("on_epoch_begin", **on_epoch_kwargs)

            self.run_single_epoch(
                iterator_train,
                training=True,
                prefix="train",
                step_fn=self.train_step,
                **fit_params
            )

            if dataset_valid is not None:
                self.run_single_epoch(
                    iterator_valid,
                    training=False,
                    prefix="valid",
                    step_fn=self.validation_step,
                    **fit_params
                )

            self.notify("on_epoch_end", **on_epoch_kwargs)

            # For tcga, stop early once training accuracy saturates.
            if self.task == 'tcga':
                train_acc = round(self.history[:, 'train_acc'][-1], rounding_digits)
                if train_acc == 1:
                    break

        return self
    def train_step(self, X, y=None):
        # Batches arrive as (inputs, labels); the last input column holds the
        # per-sample loss weights.
        y = X[1]
        X = X[0]
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)
        self.module_.train()
        self.module_.zero_grad()
        gene_embedd, second_input_embedd, activations, _, _ = self.module_(X[:, :-1], train=True)

        loss = self.get_loss([gene_embedd, second_input_embedd, activations, self.curr_epoch], y)

        # Weight the per-sample losses, then reduce.
        loss = loss * sample_weights
        loss = loss.mean()

        loss.backward()

        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip)
        self.optimizer_.step()

        return {"X": X, "y": y, "loss": loss, "y_pred": [gene_embedd, second_input_embedd, activations]}
    def validation_step(self, X, y=None):
        y = X[1]
        X = X[0]
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)
        self.module_.eval()
        with torch.no_grad():
            gene_embedd, second_input_embedd, activations, _, _ = self.module_(X[:, :-1])
            loss = self.get_loss([gene_embedd, second_input_embedd, activations, self.curr_epoch], y)

        # Weight the per-sample losses, then reduce.
        loss = loss * sample_weights
        loss = loss.mean()

        return {"X": X, "y": y, "loss": loss, "y_pred": [gene_embedd, second_input_embedd, activations]}
    def get_attention_scores(self, X, y=None):
        '''
        Returns attention scores for a given input.
        '''
        self.module_.eval()
        with torch.no_grad():
            _, _, _, attn_scores_first, attn_scores_second = self.module_(X[:, :-1])

        attn_scores_first = attn_scores_first.detach().cpu().numpy()
        if attn_scores_second is not None:
            attn_scores_second = attn_scores_second.detach().cpu().numpy()
        return attn_scores_first, attn_scores_second
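    # Usage sketch (X_batch follows the same layout as in train_step; the
    # forward pass drops the trailing weight column itself):
    #
    #     attn_first, attn_second = net.get_attention_scores(X_batch)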
    def predict(self, X):
        self.module_.train(False)
        # Inference only: no gradients are needed here.
        with torch.no_grad():
            embedds = self.module_(X[:, :-1])
        sample_weights = X[:, -1]
        if self.device == 'cuda':
            sample_weights = sample_weights.to(self.train_config.device)

        gene_embedd, second_input_embedd, activations, _, _ = embedds
        if self.save_embedding:
            self.gene_embedds.append(gene_embedd.detach().cpu())
            if second_input_embedd is not None:
                self.second_input_embedds.append(second_input_embedd.detach().cpu())

        # Append the sample weights as a last column so downstream code can
        # recover them alongside the activations.
        predictions = torch.cat([activations, sample_weights[:, None]], dim=1)
        return predictions
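    # The weights can be split back off the returned tensor downstream:
    #
    #     preds = net.predict(X_batch)
    #     activations, weights = preds[:, :-1], preds[:, -1]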
    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        # Optionally log weight and gradient histograms to TensorBoard.
        for _, m in self.module_.named_modules():
            for pn, p in m.named_parameters():
                if pn.endswith("weight") and "norm" not in pn:
                    if p.grad is not None and self.log_tb:
                        from ..callbacks.tbWriter import writer
                        writer.add_histogram("weights/" + pn, p, len(net.history))
                        writer.add_histogram(
                            "gradients/" + pn, p.grad.data, len(net.history)
                        )
    def configure_opt(self, l2_weight_decay):
        '''
        Build optimizer parameter groups: weight decay is applied to all
        parameters except biases and LayerNorm weights.
        '''
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [
            p
            for n, p in self.module_.named_parameters()
            if not any(nd in n for nd in no_decay)
        ]
        params_nodecay = [
            p
            for n, p in self.module_.named_parameters()
            if any(nd in n for nd in no_decay)
        ]
        optim_groups = [
            {"params": params_decay, "weight_decay": l2_weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        return optim_groups
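    # The groups plug directly into a torch optimizer (sketch; AdamW is an
    # assumption here, any optimizer that accepts parameter groups works):
    #
    #     groups = self.configure_opt(l2_weight_decay=0.01)
    #     optimizer = torch.optim.AdamW(groups, lr=1e-3)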
    def initialize_optimizer(self, triggered_directly=True):
        """Initialize the model optimizer. If ``self.optimizer__lr``
        is not set, use ``self.lr`` instead.

        Parameters
        ----------
        triggered_directly : bool (default=True)
            Only relevant when the optimizer is re-initialized.
            Initialization of the optimizer can be triggered directly
            (e.g. when lr was changed) or indirectly (e.g. when the
            module was re-initialized). If and only if the former
            happens, the user should receive a message informing them
            about the parameters that caused the re-initialization.

        """
        optimizer_params = self.main_config["train_config"]
        kwargs = {"lr": optimizer_params.learning_rate}

        args = self.configure_opt(optimizer_params.l2_weight_decay)

        if self.initialized_ and self.verbose:
            msg = self._format_reinit_msg(
                "optimizer", kwargs, triggered_directly=triggered_directly
            )
            logger.info(msg)

        self.optimizer_ = self.optimizer(args, lr=kwargs["lr"])

        self._register_virtual_param(
            ["optimizer__param_groups__*__*", "optimizer__*", "lr"],
            optimizer_setter,
        )
    def initialize_criterion(self):
        """Initializes the criterion, constructed from the main config."""
        self.criterion_ = self.criterion(self.main_config)
        if isinstance(self.criterion_, torch.nn.Module):
            self.criterion_ = to_device(self.criterion_, self.device)
        return self
    def initialize_callbacks(self):
        """Initializes all callbacks and saves the result in the
        ``callbacks_`` attribute.

        Both ``default_callbacks`` and ``callbacks`` are used (in that
        order). Callbacks may either be initialized or not, and if
        they don't have a name, the name is inferred from the class
        name. The ``initialize`` method is called on all callbacks.

        The final result will be a list of tuples, where each tuple
        consists of a name and an initialized callback. If names are
        not unique, a ValueError is raised.

        """
        if self.callbacks == "disable":
            self.callbacks_ = []
            return self

        callbacks_ = []

        class Dummy:
            # Sentinel marking "no callback parameter was set".
            pass

        for name, cb in self._uniquely_named_callbacks():
            # A callback passed as a parameter overrides the default one.
            param_callback = getattr(self, "callbacks__" + name, Dummy)
            if param_callback is not Dummy:
                cb = param_callback

            # Collect the callback's parameters first, then inject the train
            # config for the learning-rate callback.
            params = self.get_params_for("callbacks__{}".format(name))
            if name == "lrcallback":
                params["config"] = self.main_config["train_config"]
            if (cb is None) and params:
                raise ValueError(
                    "Trying to set a parameter for callback {} "
                    "which does not exist.".format(name)
                )
            if cb is None:
                continue

            if isinstance(cb, type):
                # An uninstantiated callback class was passed.
                cb = cb(**params)
            else:
                cb.set_params(**params)
            cb.initialize()
            callbacks_.append((name, cb))

        self.callbacks_ = callbacks_

        return self
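    # Callback parameters route through skorch's double-underscore syntax
    # ("mycb"/MyCallback are hypothetical):
    #
    #     net = Net(..., callbacks=[("mycb", MyCallback)], callbacks__mycb__foo=1)
    #     net = Net(..., callbacks="disable")   # skip callbacks entirely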