Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
class TransformerCPI(nn.Module):
    """Compound-protein interaction model.

    A single graph-convolution step is applied to the atom features, the
    protein features go through ``Encoder``, and ``Decoder`` attends from the
    atoms to the encoded protein to produce the final representation.

    Args:
        protein_dim: width of the per-residue protein features.
        hidden_dim: model width shared by encoder and decoder.
        n_layers: number of encoder conv layers and of decoder layers.
        kernel_size: encoder convolution kernel size.
        dropout: dropout probability used throughout.
        n_heads: number of attention heads in the decoder.
        pf_dim: inner width of the decoder's position-wise feed-forward.
        atom_dim: width of the per-atom compound features.
    """

    def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout, n_heads, pf_dim, atom_dim=34):
        super().__init__()
        self.encoder = Encoder(protein_dim, hidden_dim, n_layers, kernel_size, dropout)
        self.decoder = Decoder(atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout)
        # Square GCN weight: atom features keep their width after the graph step.
        self.weight = nn.Parameter(torch.empty(atom_dim, atom_dim, dtype=torch.float32))
        self.init_weight()

    def init_weight(self):
        """Fill the GCN weight uniformly in ``[-1/sqrt(atom_dim), 1/sqrt(atom_dim)]``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def gcn(self, input, adj):
        """Apply one graph-convolution step: ``adj @ (input @ weight)``.

        Args:
            input: atom features, [batch, num_node, atom_dim].
            adj: adjacency matrices, [batch, num_node, num_node].

        Returns:
            Float tensor of shape [batch, num_node, atom_dim].
        """
        support = torch.matmul(input, self.weight)
        # support = [batch, num_node, atom_dim]
        return torch.bmm(adj.float(), support.float())

    def forward(self, compound, protein):
        """Run the model on a padded batch.

        Args:
            compound: ``((atom_features, atom_counts), (adjacency, _))`` where
                atom_features is [batch, atom_num, atom_dim], atom_counts is a
                1-D tensor of real atom counts per sample, and adjacency is
                [batch, atom_num, atom_num].
            protein: ``(protein_features, protein_lengths)`` where
                protein_features is [batch, protein len, protein_dim] and
                protein_lengths is a 1-D tensor of real lengths per sample.

        Returns:
            Whatever ``self.decoder`` returns; with the ``Decoder`` defined in
            this file that is a [batch, 256] tensor.
        """
        (compound, compound_lengths), (adj, _) = compound
        protein, protein_lengths = protein

        # True marks REAL (non-padded) positions. SelfAttention fills every
        # position where ``mask == 0``, so the mask has to be truthy on valid
        # entries; the previous ``>=`` comparison flagged the padding instead
        # and caused the real atoms/residues to be masked out.
        # NOTE: checkpoints trained with the inverted masks may need retraining.
        atom_index = torch.arange(compound.size(1), device=compound.device)
        residue_index = torch.arange(protein.size(1), device=protein.device)
        compound_mask = atom_index < compound_lengths.unsqueeze(1)
        protein_mask = residue_index < protein_lengths.unsqueeze(1)

        # compound_mask = [batch, 1, atom_num, 1]   (broadcast over heads and keys)
        # protein_mask  = [batch, 1, 1, protein len] (broadcast over heads and queries)
        compound_mask = compound_mask.unsqueeze(1).unsqueeze(3)
        protein_mask = protein_mask.unsqueeze(1).unsqueeze(2)

        compound = self.gcn(compound.float(), adj)
        # compound = [batch, atom_num, atom_dim]

        enc_src = self.encoder(protein)
        # enc_src = [batch, protein len, hid dim]

        return self.decoder(compound, enc_src, compound_mask, protein_mask)
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hidden_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, hidden_dim, n_heads, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        assert hidden_dim % n_heads == 0
        self.w_q = nn.Linear(hidden_dim, hidden_dim)
        self.w_k = nn.Linear(hidden_dim, hidden_dim)
        self.w_v = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.do = nn.Dropout(dropout)
        # Energies are divided by sqrt(per-head width).
        self.scale = (hidden_dim // n_heads) ** 0.5

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [batch size, sent len_Q, hid dim].
            key: [batch size, sent len_K, hid dim].
            value: [batch size, sent len_K, hid dim].
            mask: optional tensor broadcastable to
                [batch size, n heads, sent len_Q, sent len_K]. Positions where
                ``mask == 0`` are excluded from the softmax.

        Returns:
            Tensor of shape [batch size, sent len_Q, hid dim].
        """
        bsz = query.shape[0]
        head_dim = self.hidden_dim // self.n_heads

        def split_heads(x):
            # [batch size, sent len, hid dim] -> [batch size, n heads, sent len, head_dim]
            return x.view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)

        q = split_heads(self.w_q(query))
        k = split_heads(self.w_k(key))
        v = split_heads(self.w_v(value))

        energy = torch.matmul(q, k.permute(0, 1, 3, 2)) / self.scale
        # energy = [batch size, n heads, sent len_Q, sent len_K]

        if mask is not None:
            # Use the dtype's own minimum rather than a literal -1e10: the
            # literal is not representable in float16 and makes masked_fill
            # fail under half precision, while the softmax result is the same.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = self.do(F.softmax(energy, dim=-1))
        # attention = [batch size, n heads, sent len_Q, sent len_K]

        x = torch.matmul(attention, v)
        # x = [batch size, n heads, sent len_Q, head_dim]

        # Merge the heads back into a single hid-dim axis.
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(bsz, -1, self.n_heads * head_dim)
        # x = [batch size, sent len_Q, hid dim]

        return self.fc(x)
class Encoder(nn.Module):
    """protein feature extraction.

    Projects the protein features to ``hidden_dim`` and refines them with a
    stack of same-padded 1-D convolutions, each followed by a GLU gate and a
    scaled residual connection, then applies a final LayerNorm.

    Args:
        protein_dim: width of the input protein features.
        hidden_dim: model width.
        n_layers: number of convolutional layers (0 is allowed).
        kernel_size: convolution kernel size; must be odd.
        dropout: dropout probability applied before each convolution.
    """

    def __init__(self, protein_dim, hidden_dim, n_layers, kernel_size, dropout):
        super().__init__()

        assert kernel_size % 2 == 1, "Kernel size must be odd (for now)"

        self.input_dim = protein_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        # Residual sums are multiplied by sqrt(0.5).
        self.scale = 0.5 ** 0.5
        # Each conv emits 2 * hidden_dim channels; GLU halves them back.
        self.convs = nn.ModuleList(
            [nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size, padding=(kernel_size - 1) // 2)
             for _ in range(self.n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.input_dim, self.hidden_dim)
        # Not used in forward(); kept so existing state dicts still load.
        self.gn = nn.GroupNorm(8, hidden_dim * 2)
        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, protein):
        """Encode a batch of proteins.

        Args:
            protein: [batch size, protein len, protein_dim].

        Returns:
            Tensor of shape [batch size, protein len, hid dim].
        """
        # Project, then move channels first for Conv1d.
        hidden = self.fc(protein.float()).permute(0, 2, 1)
        # hidden = [batch size, hid dim, protein len]

        # A single running variable keeps n_layers == 0 well defined: the
        # projected input simply goes straight to the LayerNorm.
        for conv in self.convs:
            gated = F.glu(conv(self.dropout(hidden)), dim=1)
            # gated = [batch size, hid dim, protein len]
            hidden = (gated + hidden) * self.scale

        # Back to [batch size, protein len, hid dim] for LayerNorm.
        return self.ln(hidden.permute(0, 2, 1))
class PositionwiseFeedforward(nn.Module):
    """Two-layer position-wise feed-forward block built from kernel-size-1 convolutions.

    Args:
        hidden_dim: input and output width.
        pf_dim: inner (expanded) width.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, hidden_dim, pf_dim, dropout):
        super().__init__()

        self.hidden_dim, self.pf_dim = hidden_dim, pf_dim

        # Kernel size 1: each position is transformed independently.
        self.fc_1 = nn.Conv1d(hidden_dim, pf_dim, 1)
        self.fc_2 = nn.Conv1d(pf_dim, hidden_dim, 1)

        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch size, sent len, hid dim] to a tensor of the same shape."""
        # Conv1d expects channels first: [batch size, hid dim, sent len].
        channels_first = x.permute(0, 2, 1)

        # Expand to pf dim, apply ReLU and dropout.
        expanded = self.fc_1(channels_first)
        expanded = self.do(F.relu(expanded))

        # Project back to hid dim and restore [batch size, sent len, hid dim].
        projected = self.fc_2(expanded)
        return projected.permute(0, 2, 1)
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over the encoder output, feed-forward.

    Each sub-layer is wrapped as ``ln(x + dropout(sublayer(x)))``; the same
    LayerNorm instance is reused for all three sub-layers.

    Args:
        hidden_dim: model width.
        n_heads: number of attention heads.
        pf_dim: inner width of the feed-forward sub-layer.
        dropout: dropout probability.
        self_attention: class used to build both attention sub-layers.
        positionwise_feedforward: class used to build the feed-forward sub-layer.
    """

    def __init__(self, hidden_dim, n_heads, pf_dim, dropout,
                 self_attention=SelfAttention,
                 positionwise_feedforward=PositionwiseFeedforward):
        super().__init__()

        self.ln = nn.LayerNorm(hidden_dim)
        # sa: target attends to itself; ea: target attends to the encoder output.
        self.sa = self_attention(hidden_dim, n_heads, dropout)
        self.ea = self_attention(hidden_dim, n_heads, dropout)
        self.pf = positionwise_feedforward(hidden_dim, pf_dim, dropout)
        self.do = nn.Dropout(dropout)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Apply the three sub-layers in order.

        Args:
            trg: compound representation; the residual additions require it to
                have the same shape as each sub-layer output.
            src: encoder output, used as key and value of the second attention.
            trg_mask: forwarded as ``mask`` to the self-attention.
            src_mask: forwarded as ``mask`` to the encoder attention.

        Returns:
            Tensor with the same shape as ``trg``.
        """
        # 1) self-attention over the compound
        self_attended = self.sa(trg, trg, trg, trg_mask)
        trg = self.ln(trg + self.do(self_attended))

        # 2) attention from the compound to the protein
        cross_attended = self.ea(trg, src, src, src_mask)
        trg = self.ln(trg + self.do(cross_attended))

        # 3) position-wise feed-forward
        transformed = self.pf(trg)
        return self.ln(trg + self.do(transformed))
class Decoder(nn.Module):
    """ compound feature extraction.

    Projects atom features to ``hidden_dim``, runs them through a stack of
    decoder layers that attend to the encoded protein, pools the atoms into a
    single vector with norm-based softmax weights, and applies a final
    ``Linear(hidden_dim, 256)`` + ReLU.

    Args:
        atom_dim: width of the incoming atom features.
        hidden_dim: model width.
        n_layers: number of decoder layers.
        n_heads: number of attention heads.
        pf_dim: inner width of the feed-forward sub-layers.
        dropout: dropout probability.
        decoder_layer: class used to build each layer.
        self_attention: attention class handed to each layer.
        positionwise_feedforward: feed-forward class handed to each layer.
    """

    def __init__(self, atom_dim, hidden_dim, n_layers, n_heads, pf_dim, dropout,
                 decoder_layer=DecoderLayer,
                 self_attention=SelfAttention,
                 positionwise_feedforward=PositionwiseFeedforward):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dim)
        self.output_dim = atom_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.decoder_layer = decoder_layer
        self.self_attention = self_attention
        self.positionwise_feedforward = positionwise_feedforward
        self.dropout = dropout
        # self.sa (and self.gn / self.ln / self.do below) are not used in
        # forward(); they are kept so existing state dicts still load.
        self.sa = self_attention(hidden_dim, n_heads, dropout)
        self.layers = nn.ModuleList(
            [decoder_layer(hidden_dim, n_heads, pf_dim, dropout, self_attention, positionwise_feedforward)
             for _ in range(n_layers)])
        self.ft = nn.Linear(atom_dim, hidden_dim)
        self.do = nn.Dropout(dropout)
        self.fc_1 = nn.Linear(hidden_dim, 256)
        self.gn = nn.GroupNorm(8, 256)

    def forward(self, trg, src, trg_mask=None, src_mask=None):
        """Decode a batch of compounds against the encoded proteins.

        Args:
            trg: [batch_size, compound len, atom_dim].
            src: [batch_size, protein len, hidden_dim], the encoder output.
            trg_mask: mask forwarded to every layer's self-attention.
            src_mask: mask forwarded to every layer's encoder attention.

        Returns:
            Tensor of shape [batch size, 256].
        """
        trg = self.ft(trg)
        # trg = [batch size, compound len, hid dim]

        for layer in self.layers:
            trg = layer(trg, src, trg_mask, src_mask)
        # trg = [batch size, compound len, hid dim]

        # Use norm to determine which atom is significant: the L2 norm of each
        # atom vector, softmaxed over the atoms, is that atom's pooling weight.
        # NOTE(review): trg_mask is not applied here, so padded atoms also
        # receive a pooling weight.
        weights = F.softmax(torch.norm(trg, dim=2), dim=1)
        # weights = [batch size, compound len]

        # Weighted sum over atoms as one vectorized reduction (replaces a
        # Python loop over every (sample, atom) pair).
        pooled = (trg * weights.unsqueeze(-1)).sum(dim=1)
        # pooled = [batch size, hid dim]

        return F.relu(self.fc_1(pooled))