from functools import partial
from typing import Optional, Sequence, Dict

from torch import nn, optim, Tensor
from lightning import LightningModule
from torchmetrics import Metric, MetricCollection


class DTILightningModule(LightningModule):
    """
    Drug-target interaction prediction.

    Args:
        optimizer: a partially or fully initialized instance of class torch.optim.Optimizer
        scheduler: a partially initialized learning-rate scheduler, or a Lightning
            lr_scheduler config dict whose 'scheduler' entry is partially initialized
        predictor: a fully initialized instance of class torch.nn.Module
        metrics: a dict mapping names to fully initialized instances of class torchmetrics.Metric
        out: a fully initialized instance of class torch.nn.Module that maps the
            predictor output to logits
        loss: a fully initialized instance of class torch.nn.Module
        activation: a fully initialized instance of class torch.nn.Module that maps
            logits to predictions
    """

    # Extra batch keys passed through to the step outputs so that callbacks
    # (e.g. prediction writers) can associate predictions with their inputs
    extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']

    def __init__(
            self,
            optimizer: optim.Optimizer,
            scheduler: Optional[partial | Dict],
            predictor: nn.Module,
            metrics: Optional[Dict[str, Metric]] = None,
            out: Optional[nn.Module] = None,
            loss: Optional[nn.Module] = None,
            activation: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.predictor = predictor
        self.out = out
        self.loss = loss
        self.activation = activation
        # Metrics are automatically averaged over batches.
        # Separate metric instances for the train, val and test steps ensure
        # a proper reduction over each epoch.
        metrics = MetricCollection(dict(metrics or {}))
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")
        # Allow access to init params via `self.hparams` and ensure they are stored in checkpoints
        self.save_hyperparameters(logger=False,
                                  ignore=['predictor', 'out', 'loss', 'activation', 'metrics'])
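
    # A note on metric choice (an assumption, not from the original source): the step
    # methods below call the collections as `metrics(preds=..., target=..., indexes=...)`,
    # which matches retrieval-style torchmetrics such as RetrievalMRR or RetrievalAUROC.
    # MetricCollection filters out unsupported kwargs per metric, so plain classification
    # metrics like BinaryAUROC would simply ignore `indexes`. A plausible config:
    #
    #     from torchmetrics.retrieval import RetrievalMRR, RetrievalAUROC
    #     metrics = {'mrr': RetrievalMRR(), 'auroc': RetrievalAUROC()}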

    def setup(self, stage):
        match stage:
            case 'fit':
                # Run one training batch through the model so that lazily initialized
                # submodules (e.g. nn.LazyLinear) are materialized before training starts
                dataloader = self.trainer.datamodule.train_dataloader()
                dummy_batch = next(iter(dataloader))
                self.forward(dummy_batch)
            # case 'validate':
            #     dataloader = self.trainer.datamodule.val_dataloader()
            # case 'test':
            #     dataloader = self.trainer.datamodule.test_dataloader()
            # case 'predict':
            #     dataloader = self.trainer.datamodule.predict_dataloader()
        # for key, value in dummy_batch.items():
        #     if isinstance(value, Tensor):
        #         dummy_batch[key] = value.to(self.device)

    def forward(self, batch):
        output = self.predictor(batch['X1^'], batch['X2^'])
        target = batch.get('Y')
        indexes = batch.get('ID^')
        preds = None
        loss = None
        if isinstance(output, Tensor):
            output = self.out(output).squeeze(1)
            preds = self.activation(output)
        elif isinstance(output, Sequence):
            output = list(output)
            # If multi-objective, assume the zeroth element of `output` is the main
            # objective while the rest are auxiliary
            output[0] = self.out(output[0]).squeeze(1)
            # Downstream metric evaluation only needs the main-objective preds
            preds = self.activation(output[0])
        if target is not None:
            loss = self.loss(output, target.float())
        return preds, target, indexes, loss
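
    # A minimal sketch (an assumption, not from the original source) of head modules
    # consistent with `forward` for binary interaction labels: `out` projects the
    # predictor output to one logit per pair, `loss` consumes logits, and `activation`
    # turns logits into probabilities for the metrics (`hidden_dim` is hypothetical):
    #
    #     out = nn.Linear(hidden_dim, 1)
    #     loss = nn.BCEWithLogitsLoss()
    #     activation = nn.Sigmoid()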

    def training_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.train_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_train_epoch_end(self):
        pass

    def validation_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.val_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        preds, target, indexes, loss = self.forward(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.test_metrics(preds=preds, target=target, indexes=indexes.long())
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return_dict = {
            'Y^': preds,
            'Y': target,
            'loss': loss,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        preds, _, _, _ = self.forward(batch)
        # Return a dictionary for callbacks like BasePredictionWriter
        return_dict = {
            'Y^': preds,
        }
        for key in self.extra_return_keys:
            if key in batch:
                return_dict[key] = batch[key]
        return return_dict
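
    # A minimal callback sketch (an assumption, not part of the original source) showing
    # how the dict returned by `predict_step` can be consumed by Lightning's
    # BasePredictionWriter (from lightning.pytorch.callbacks):
    #
    #     class CSVPredictionWriter(BasePredictionWriter):
    #         def __init__(self, output_dir):
    #             super().__init__(write_interval='batch')
    #             self.output_dir = output_dir
    #
    #         def write_on_batch_end(self, trainer, pl_module, prediction,
    #                                batch_indices, batch, batch_idx, dataloader_idx):
    #             ...  # persist prediction['Y^'] alongside prediction.get('ID1'), etc.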

    def configure_optimizers(self):
        optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
        if self.hparams.get('scheduler'):
            if isinstance(self.hparams.scheduler, partial):
                # A bare partial is wrapped in a default Lightning lr_scheduler config
                optimizers_config['lr_scheduler'] = {
                    "scheduler": self.hparams.scheduler(optimizer=optimizers_config['optimizer']),
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                }
            else:
                # Copy the config dict before instantiating its 'scheduler' entry,
                # so the stored hyperparameters are left untouched
                scheduler_config = dict(self.hparams.scheduler)
                scheduler_config['scheduler'] = scheduler_config['scheduler'](
                    optimizer=optimizers_config['optimizer']
                )
                optimizers_config['lr_scheduler'] = scheduler_config
        return optimizers_config
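

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module: `ToyPredictor` is a
    # hypothetical two-tower stand-in for real drug/protein encoders, and the batch
    # keys ('X1^', 'X2^', 'Y', 'ID^') follow the schema expected by `forward`.
    import torch
    from torchmetrics.retrieval import RetrievalMRR

    class ToyPredictor(nn.Module):
        def __init__(self, drug_dim=64, protein_dim=128, hidden_dim=32):
            super().__init__()
            self.drug_encoder = nn.Linear(drug_dim, hidden_dim)
            self.protein_encoder = nn.Linear(protein_dim, hidden_dim)

        def forward(self, x1, x2):
            # Elementwise product of the two embeddings as a crude interaction feature
            return self.drug_encoder(x1) * self.protein_encoder(x2)

    model = DTILightningModule(
        optimizer=partial(optim.AdamW, lr=1e-4),
        scheduler=partial(optim.lr_scheduler.ReduceLROnPlateau, patience=5),
        predictor=ToyPredictor(),
        metrics={'mrr': RetrievalMRR()},  # retrieval metrics consume the `indexes` kwarg
        out=nn.LazyLinear(1),             # projects interaction features to one logit per pair
        loss=nn.BCEWithLogitsLoss(),
        activation=nn.Sigmoid(),
    )
    batch = {
        'X1^': torch.randn(8, 64),
        'X2^': torch.randn(8, 128),
        'Y': torch.randint(0, 2, (8,)),
        'ID^': torch.arange(8),
    }
    preds, target, indexes, loss = model(batch)
    print(preds.shape, float(loss))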