| | import pdb |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from .layers import * |
| | from .modules import * |
| | import pdb |
| | from transformers import EsmModel, EsmTokenizer |
| |
|
def to_var(x):
    """Move *x* onto the GPU when CUDA is available; otherwise return it unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
| |
|
| |
|
class RepeatedModule3(nn.Module):
    """Stack of ``ReciprocalLayerwithCNN`` blocks over peptide/protein embeddings.

    Both inputs are linearly projected from 1280 features (presumably the
    ESM-2 650M hidden size used by ``FullModel`` — confirm against caller)
    down to ``d_model``, then passed through ``n_layers`` reciprocal layers
    that attend within and across the two sequences.
    """

    def __init__(self, n_layers, d_model, d_hidden,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear1 = nn.Linear(1280, d_model)   # peptide projection
        self.linear2 = nn.Linear(1280, d_model)   # protein projection
        # NOTE(review): never used in forward(); kept for checkpoint/state-dict
        # compatibility.
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        """Run the reciprocal layer stack.

        Args:
            peptide_sequence: peptide embeddings, last dim 1280.
            protein_sequence: protein embeddings, last dim 1280.

        Returns:
            Tuple of (prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list,
            seq_prot_attention_list), the lists holding one attention map
            per layer.
        """
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, \
                prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        # BUG FIX: previously returned seq_prot_attention_list twice,
        # silently dropping prot_seq_attention_list.
        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
| |
|
| |
|
class RepeatedModule2(nn.Module):
    """Stack of ``ReciprocalLayer`` blocks over peptide/protein embeddings.

    Both inputs are linearly projected from 1280 features (presumably the
    ESM-2 650M hidden size used by ``FullModel`` — confirm against caller)
    down to ``d_model``, then passed through ``n_layers`` reciprocal layers.
    """

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear1 = nn.Linear(1280, d_model)   # peptide projection
        self.linear2 = nn.Linear(1280, d_model)   # protein projection
        # NOTE(review): never used in forward(); kept for checkpoint/state-dict
        # compatibility.
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        """Run the reciprocal layer stack.

        Args:
            peptide_sequence: peptide embeddings, last dim 1280.
            protein_sequence: protein embeddings, last dim 1280.

        Returns:
            Tuple of (prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list,
            seq_prot_attention_list), the lists holding one attention map
            per layer.
        """
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, \
                prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        # BUG FIX: previously returned seq_prot_attention_list twice,
        # silently dropping prot_seq_attention_list.
        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
| |
|
| |
|
class RepeatedModule(nn.Module):
    """Stack of ``ReciprocalLayer`` blocks over a token-encoded peptide and a
    pre-embedded protein.

    Unlike ``RepeatedModule2``/``RepeatedModule3``, the peptide arrives as
    integer tokens (vocabulary of 20, presumably amino acids — confirm with
    caller) and is embedded plus sinusoidally position-encoded here, while
    the protein arrives as 1024-dim embeddings and is linearly projected.
    """

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear = nn.Linear(1024, d_model)            # protein projection
        self.sequence_embedding = nn.Embedding(20, d_model)  # peptide token embedding
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def _positional_embedding(self, batches, number):
        """Sinusoidal positional encoding of shape ``(1, number, d_model)``.

        ``batches`` is unused (the result broadcasts over the batch dim);
        kept for call-site compatibility. ``d_model`` must be even so the
        sin/cos halves concatenate back to ``d_model`` features.
        """
        # Inverse frequencies: 10000^(-2i/d_model) for i in [0, d_model/2).
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(np.log(10000) / self.d_model))

        positions = torch.arange(0, number, dtype=torch.float32)
        positions = positions.unsqueeze(0).unsqueeze(2)   # (1, number, 1)

        angles = positions * inv_freq                     # (1, number, d_model/2)

        return torch.cat((torch.sin(angles), torch.cos(angles)), 2)

    def forward(self, peptide_sequence, protein_sequence):
        """Run the reciprocal layer stack.

        Args:
            peptide_sequence: integer token tensor, shape (batch, pep_len).
            protein_sequence: protein embeddings, last dim 1024.

        Returns:
            Tuple of (prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list,
            seq_prot_attention_list), the lists holding one attention map
            per layer.
        """
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.sequence_embedding(peptide_sequence)
        sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0],
                                                          peptide_sequence.shape[1]))
        sequence_enc = self.dropout(sequence_enc)

        prot_enc = self.dropout_2(self.linear(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, \
                prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        # BUG FIX: previously returned seq_prot_attention_list twice,
        # silently dropping prot_seq_attention_list.
        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list
| |
|
| |
|
class FullModel(nn.Module):
    """End-to-end binder/target interaction model on frozen ESM-2 embeddings.

    Embeds both token sets with a frozen ``facebook/esm2_t33_650M_UR50D``
    encoder, runs ``RepeatedModule2``, applies one final protein-over-peptide
    cross-attention plus FFN, and projects each protein position to a
    sigmoid score in (0, 1).
    """

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

        # Freeze the language model; only the downstream layers train.
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner,
                                               dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v,
                                                                dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

        # NOTE(review): stored but never consulted in forward(); attention
        # maps are always discarded (cf. Original_FullModel which honors it).
        self.return_attention = return_attention

    def forward(self, binder_tokens, target_tokens):
        """Score each target (protein) position for interaction.

        Args:
            binder_tokens: tokenizer output dict for the binder sequence.
            target_tokens: tokenizer output dict for the target sequence.

        Returns:
            Tensor of per-position sigmoid scores, last dim 1.
        """
        # Frozen encoder: no gradients needed through ESM.
        with torch.no_grad():
            peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
            protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        # BUG FIX: the unpack previously bound seq_prot_attention_list twice,
        # shadowing prot_seq_attention_list; names now match the callee.
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = \
            self.repeated_module(peptide_sequence, protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.sigmoid(self.output_projection_prot(prot_enc))

        return prot_enc
| |
|
| |
|
| |
|
class Original_FullModel(nn.Module):
    """Original interaction model on pre-computed embeddings.

    Takes peptide tokens and 1024-dim protein embeddings, runs
    ``RepeatedModule``, applies one final protein-over-peptide
    cross-attention plus FFN, and emits a per-position 2-class
    log-softmax. When ``return_attention`` is set, the per-layer
    attention maps are returned as well.
    """

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()
        self.repeated_module = RepeatedModule(n_layers, d_model,
                                              n_head, d_k, d_v, d_inner,
                                              dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v,
                                                                dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 2)

        self.softmax_prot = nn.LogSoftmax(dim=-1)

        self.return_attention = return_attention

    def forward(self, peptide_sequence, protein_sequence):
        """Classify each protein position (2-way, log-probabilities).

        Args:
            peptide_sequence: integer peptide token tensor.
            protein_sequence: protein embeddings, last dim 1024.

        Returns:
            Log-softmax scores (last dim 2); with ``return_attention`` also
            the four per-layer attention-map lists.
        """
        # BUG FIX: the unpack previously bound seq_prot_attention_list twice,
        # so the return_attention branch returned the same list under two
        # names; names now match the callee.
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = \
            self.repeated_module(peptide_sequence, protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc))

        if not self.return_attention:
            return prot_enc
        else:
            return prot_enc, sequence_attention_list, prot_attention_list, \
                prot_seq_attention_list, seq_prot_attention_list
| | |
| |
|