Spaces:
Sleeping
Sleeping
| from math import floor | |
| import re | |
| from typing import Literal | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch | |
| import torch.nn.functional as F | |
| def conv(in_channels, out_channels, kernel_size, conv_dim, stride=1): | |
| conv_layer = None | |
| match conv_dim: | |
| case 1: | |
| conv_layer = nn.Conv1d | |
| case 2: | |
| conv_layer = nn.Conv2d | |
| case 3: | |
| conv_layer = nn.Conv3d | |
| return conv_layer(in_channels, out_channels, | |
| kernel_size=kernel_size, stride=stride, padding=floor(kernel_size / 2), bias=False) | |
| def batch_norm(out_channels, conv_dim): | |
| bn_layer = None | |
| match conv_dim: | |
| case 1: | |
| bn_layer = nn.BatchNorm1d | |
| case 2: | |
| bn_layer = nn.BatchNorm2d | |
| case 3: | |
| bn_layer = nn.BatchNorm3d | |
| return bn_layer(out_channels) | |
def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 2-D convolution with a 3x3 kernel and padding 1."""
    kernel = 3
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=False,
    )
def conv5x5(in_channels, out_channels, stride=1):
    """Return a bias-free 2-D convolution with a 5x5 kernel and padding 2."""
    kernel = 5
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=False,
    )
def conv1x1(in_channels, out_channels, stride=1):
    """Return a bias-free 2-D pointwise (1x1) convolution with no padding."""
    kernel = 1
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=False,
    )
# Residual block
class ResidualBlock(nn.Module):
    """
    Residual block: conv(k=5, stride) -> BN -> ELU -> conv(k=3, stride 1) -> BN,
    plus a skip connection, followed by ELU.

    in_channels : {int} channels of the block input
    out_channels: {int} channels produced by both convolutions
    conv_dim    : {1, 2, 3} spatial rank, forwarded to ``conv`` / ``batch_norm``
    stride      : {int} stride of the first convolution only
    downsample  : {nn.Module or None} applied to the input on the skip path so it matches the
                  main path's shape; needed whenever stride != 1 or in_channels != out_channels
    """

    def __init__(self, in_channels, out_channels, conv_dim, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv(in_channels, out_channels, kernel_size=5, conv_dim=conv_dim, stride=stride)
        self.bn1 = batch_norm(out_channels, conv_dim=conv_dim)
        self.elu = nn.ELU(inplace=True)
        # Stride 1 on purpose: the main path must be downsampled only once (in conv1) so its
        # shape matches a skip path built with the same stride.
        self.conv2 = conv(out_channels, out_channels, kernel_size=3, conv_dim=conv_dim)
        self.bn2 = batch_norm(out_channels, conv_dim=conv_dim)
        self.downsample = downsample

    def forward(self, x):
        out = self.elu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Explicit None check: module truthiness is not a reliable "is set" test
        # (e.g. an empty nn.Sequential is falsy).
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.elu(out)
class DrugVQA(nn.Module):
    """
    The class is an implementation of the DrugVQA model without pruning.
    Slight modifications have been done for speedup.

    The attention regularization term is not computed here: ``forward`` returns the protein
    attention matrix so a separate module (e.g. ``AttentionL2Regularization``) can penalize it.

    Drug branch: SMILES token embedding -> 2-layer bidirectional LSTM -> r-head self-attention
    pooling. Protein branch: CNN stem + two residual stages -> average over every spatial axis
    except the last -> r-head self-attention pooling over the remaining positions. The two
    pooled vectors are concatenated and passed through a dense layer with ReLU.
    """

    def __init__(
        self,
        conv_dim: Literal[1, 2, 3],
        lstm_hid_dim: int,
        d_a: int,
        r: int,
        n_chars_smi: int,
        n_chars_seq: int,
        dropout: float,
        in_channels: int,
        cnn_channels: int,
        cnn_layers: int,
        emb_dim: int,
        dense_hid: int,
    ):
        """
        conv_dim    : {1, 2, 3} spatial rank of the protein CNN (Conv1d / Conv2d / Conv3d)
        lstm_hid_dim: {int} hidden dimension of each direction of the bidirectional LSTM
        d_a         : {int} hidden dimension of the attention MLPs (drug and protein branches)
        r           : {int} attention-hops or attention heads
        n_chars_smi : {int} vocabulary size of SMILES tokens
        n_chars_seq : {int} vocabulary size of protein sequence (currently unused)
        dropout     : {float} dropout between the two LSTM layers
        in_channels : {int} output channels of the CNN stem (its input has a single channel)
        cnn_channels: {int} channels of the residual CNN blocks
        cnn_layers  : {int} number of residual blocks in each of the two CNN stages
        emb_dim     : {int} SMILES embedding dimension
        dense_hid   : {int} output size of the final dense layer
        """
        super().__init__()
        self.conv_dim = conv_dim
        self.lstm_hid_dim = lstm_hid_dim
        self.r = r
        self.in_channels = in_channels
        # rnn
        self.embeddings = nn.Embedding(n_chars_smi, emb_dim)
        self.lstm = nn.LSTM(emb_dim, self.lstm_hid_dim, 2, batch_first=True, bidirectional=True,
                            dropout=dropout)
        self.linear_first = nn.Linear(2 * self.lstm_hid_dim, d_a)
        self.linear_second = nn.Linear(d_a, r)
        self.linear_first_seq = nn.Linear(cnn_channels, d_a)
        self.linear_second_seq = nn.Linear(d_a, self.r)
        # cnn
        self.conv = conv(1, self.in_channels, kernel_size=3, conv_dim=conv_dim)
        self.bn = batch_norm(in_channels, conv_dim=conv_dim)
        self.elu = nn.ELU(inplace=False)
        self.layer1 = self.make_layer(cnn_channels, cnn_layers)
        self.layer2 = self.make_layer(cnn_channels, cnn_layers)
        # The pooled protein embedding has cnn_channels features (seq_att @ pic_emb), not d_a;
        # sizing this layer with d_a only worked when d_a == cnn_channels.
        self.linear_final_step = nn.Linear(self.lstm_hid_dim * 2 + cnn_channels, dense_hid)
        # Normalizes attention scores over the position axis (dim 1 of a (B, L, r) tensor).
        self.softmax = nn.Softmax(dim=1)

    def make_layer(self, out_channels, blocks, stride=1):
        """
        Build one CNN stage of ResidualBlocks mapping ``self.in_channels`` -> ``out_channels``.

        The first block gets a conv+BN projection on its skip path when the stride or channel
        count changes; at least one block is always created even if ``blocks`` < 1.
        Side effect: ``self.in_channels`` is set to ``out_channels`` so the next call chains
        from this stage's output.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                conv(self.in_channels, out_channels, kernel_size=3, conv_dim=self.conv_dim, stride=stride),
                batch_norm(out_channels, conv_dim=self.conv_dim),
            )
        layers = [ResidualBlock(self.in_channels, out_channels,
                                conv_dim=self.conv_dim, stride=stride, downsample=downsample)]
        self.in_channels = out_channels
        layers.extend(ResidualBlock(out_channels, out_channels, conv_dim=self.conv_dim)
                      for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _attention_pool(self, x, first, second):
        """
        Pool ``x`` of shape (B, L, F) over L with r-head self-attention.

        Returns the head-averaged embedding (B, F) and the attention matrix (B, r, L).
        """
        att = second(torch.tanh(first(x)))       # (B, L, r)
        att = self.softmax(att).transpose(1, 2)  # softmax over L, then (B, r, L)
        pooled = att @ x                         # (B, r, F)
        return torch.sum(pooled, 1) / self.r, att

    def forward(self, enc_drug, enc_protein):
        """
        enc_drug   : pair whose first element holds SMILES token ids, (batch, seq_len);
                     the second element is ignored
        enc_protein: pair whose first element holds the protein input, (batch, *spatial) with
                     ``conv_dim`` spatial axes; the second element is ignored

        Returns:
            sscomplex: (batch, dense_hid) ReLU-activated joint representation; no
                       classification layer is applied here
            seq_att  : (batch, r, L) protein attention, L being the size of the last
                       spatial axis of the CNN output
        """
        enc_drug, _ = enc_drug
        enc_protein, _ = enc_protein

        # Drug branch: LSTM outputs are (B, T, 2 * lstm_hid_dim) (batch_first, bidirectional).
        smile_embed = self.embeddings(enc_drug.long())
        outputs, _ = self.lstm(smile_embed)
        avg_sentence_embed, _ = self._attention_pool(outputs, self.linear_first, self.linear_second)

        # Protein branch: add a single input channel, then stem + residual stages.
        pic = self.conv(enc_protein.float().unsqueeze(1))
        pic = self.bn(pic)
        pic = self.elu(pic)
        pic = self.layer1(pic)
        pic = self.layer2(pic)  # (B, C, *spatial)
        # Average every spatial axis except the last, leaving (B, C, L) for any conv_dim
        # (for conv_dim=2 this is a mean over dim 2); the last axis is what attention pools over.
        reduce_dims = tuple(range(2, pic.dim() - 1))
        if reduce_dims:
            pic = pic.mean(dim=reduce_dims)
        pic_emb = pic.permute(0, 2, 1)  # (B, L, C)
        avg_seq_embed, seq_att = self._attention_pool(pic_emb, self.linear_first_seq, self.linear_second_seq)

        sscomplex = torch.cat([avg_sentence_embed, avg_seq_embed], dim=1)
        sscomplex = torch.relu(self.linear_final_step(sscomplex))
        return sscomplex, seq_att
class AttentionL2Regularization(nn.Module):
    """
    Penalty pushing the heads of an attention matrix towards orthonormality.

    For attention ``seq_att`` of shape (batch, r, L) the loss is the batch mean of
    ||A A^T - I||_F, where I is the r x r identity.
    """

    def forward(self, seq_att):
        identity = torch.eye(seq_att.size(1), device=seq_att.device, dtype=seq_att.dtype)
        gram = seq_att @ seq_att.transpose(1, 2)  # (batch, r, r)
        # identity (r, r) broadcasts over the batch dimension.
        return torch.mean(self.l2_matrix_norm(gram - identity))

    @staticmethod
    def l2_matrix_norm(m):
        """
        m = ||A * A_T - I||
        Missing from the original DrugVQA GitHub source code.
        Opting to use the faster Frobenius norm rather than the induced L2 matrix norm (spectral norm)
        proposed in the original research, because the goal is to minimize the difference between
        the attention matrix and the identity matrix.

        m: batch of matrices (batch, n, n); returns the per-matrix Frobenius norm, shape (batch,).
        A static method because it uses no instance state and is called as
        ``self.l2_matrix_norm(m)``.
        """
        return torch.linalg.norm(m, ord='fro', dim=(1, 2))