| import pdb | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from .layers import * | |
| from .modules import * | |
| import pdb | |
| from transformers import EsmModel, EsmTokenizer | |
def to_var(x):
    """Move ``x`` to the default CUDA device when one is available.

    Args:
        x: A tensor (or any object exposing a ``.cuda()`` method).

    Returns:
        ``x.cuda()`` when ``torch.cuda.is_available()`` is true, otherwise
        ``x`` itself, unchanged.
    """
    return x.cuda() if torch.cuda.is_available() else x
class RepeatedModule3(nn.Module):
    """Stack of ``ReciprocalLayerwithCNN`` layers over pre-computed embeddings.

    Peptide and protein features are each projected from ``d_embedding`` to
    ``d_model`` features, passed through dropout, and then refined jointly by
    ``n_layers`` reciprocal layers.

    Args:
        n_layers: Number of reciprocal layers in the stack.
        d_model: Hidden size used inside the stack.
        d_hidden: Forwarded to ``ReciprocalLayerwithCNN`` (meaning defined there).
        n_head, d_k, d_v, d_inner: Forwarded to ``ReciprocalLayerwithCNN``.
        dropout: Dropout probability applied to both projected inputs.
        d_embedding: Feature size of the incoming embeddings (default 1280).
    """

    def __init__(self, n_layers, d_model, d_hidden,
                 n_head, d_k, d_v, d_inner, dropout=0.1, d_embedding=1280):
        super().__init__()
        self.linear1 = nn.Linear(d_embedding, d_model)
        self.linear2 = nn.Linear(d_embedding, d_model)
        # Not used by forward(); kept so state_dict keys (and therefore
        # existing checkpoints) remain compatible.
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayerwithCNN(d_model, d_inner, d_hidden, n_head, d_k, d_v)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        """Run the reciprocal stack over peptide and protein embeddings.

        Args:
            peptide_sequence: Peptide features whose last dimension is ``d_embedding``.
            protein_sequence: Protein features whose last dimension is ``d_embedding``.

        Returns:
            A 6-tuple ``(prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list,
            seq_prot_attention_list)``. Each list holds one entry per layer.
        """
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []
        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))
        for reciprocal_layer in self.reciprocal_layer_stack:
            # Each layer takes (sequence, protein) and returns the updated
            # encodings first (protein first), then its four attention maps.
            (prot_enc, sequence_enc, prot_attention, sequence_attention,
             prot_seq_attention, seq_prot_attention) = reciprocal_layer(sequence_enc, prot_enc)
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)
        # Slot 5 carries the prot->seq attentions. Previously it repeated
        # seq_prot_attention_list, so these maps were collected but dropped.
        return (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
                prot_seq_attention_list, seq_prot_attention_list)
class RepeatedModule2(nn.Module):
    """Stack of ``ReciprocalLayer`` layers over pre-computed embeddings.

    Peptide and protein features are each projected from ``d_embedding`` to
    ``d_model`` features, passed through dropout, and then refined jointly by
    ``n_layers`` reciprocal layers.

    Args:
        n_layers: Number of reciprocal layers in the stack.
        d_model: Hidden size used inside the stack.
        n_head, d_k, d_v, d_inner: Forwarded to ``ReciprocalLayer``.
        dropout: Dropout probability applied to both projected inputs.
        d_embedding: Feature size of the incoming embeddings (default 1280).
    """

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1, d_embedding=1280):
        super().__init__()
        self.linear1 = nn.Linear(d_embedding, d_model)
        self.linear2 = nn.Linear(d_embedding, d_model)
        # Not used by forward(); kept so state_dict keys (and therefore
        # existing checkpoints) remain compatible.
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        """Run the reciprocal stack over peptide and protein embeddings.

        Args:
            peptide_sequence: Peptide features whose last dimension is ``d_embedding``.
            protein_sequence: Protein features whose last dimension is ``d_embedding``.

        Returns:
            A 6-tuple ``(prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list,
            seq_prot_attention_list)``. Each list holds one entry per layer.
        """
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []
        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))
        for reciprocal_layer in self.reciprocal_layer_stack:
            # Each layer takes (sequence, protein) and returns the updated
            # encodings first (protein first), then its four attention maps.
            (prot_enc, sequence_enc, prot_attention, sequence_attention,
             prot_seq_attention, seq_prot_attention) = reciprocal_layer(sequence_enc, prot_enc)
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)
        # Slot 5 carries the prot->seq attentions. Previously it repeated
        # seq_prot_attention_list, so these maps were collected but dropped.
        return (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
                prot_seq_attention_list, seq_prot_attention_list)
class RepeatedModule(nn.Module):
    """Stack of ``ReciprocalLayer`` layers over peptide tokens and protein features.

    The peptide is given as token indices. They are embedded with a learned
    20-entry table plus sinusoidal positional encodings. The protein is given
    as pre-computed features, projected from ``d_embedding`` to ``d_model``.

    Args:
        n_layers: Number of reciprocal layers in the stack.
        d_model: Hidden size used inside the stack.
        n_head, d_k, d_v, d_inner: Forwarded to ``ReciprocalLayer``.
        dropout: Dropout probability applied to both encoded inputs.
        d_embedding: Feature size of the incoming protein features (default 1024).
    """

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1, d_embedding=1024):
        super().__init__()
        self.linear = nn.Linear(d_embedding, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model
        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def _positional_embedding(self, batches, number):
        """Return sinusoidal positional encodings of shape ``(1, number, d_model)``.

        The first half of the feature axis holds sines and the second half
        holds cosines, over ``number`` positions. ``batches`` is accepted for
        backward compatibility but is unused: the leading dimension of 1
        broadcasts over the batch.
        """
        inv_freq = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32)
                             * -1 * (np.log(10000) / self.d_model))
        positions = torch.arange(0, number, dtype=torch.float32).unsqueeze(0).unsqueeze(2)
        angles = positions * inv_freq
        encoding = torch.cat((torch.sin(angles), torch.cos(angles)), 2)
        # For odd d_model the sine and cosine halves together give d_model + 1
        # columns; trim so the encoding always matches the embedding width.
        return encoding[:, :, :self.d_model]

    def forward(self, peptide_sequence, protein_sequence):
        """Encode both inputs and run the reciprocal stack.

        Args:
            peptide_sequence: Integer token indices in ``[0, 20)``, indexed as
                ``(batch, length)`` (dims 0 and 1 are read for the positional encoding).
            protein_sequence: Protein features whose last dimension is ``d_embedding``.

        Returns:
            A 6-tuple ``(prot_enc, sequence_enc, sequence_attention_list,
            prot_attention_list, prot_seq_attention_list,
            seq_prot_attention_list)``. Each list holds one entry per layer.
        """
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []
        sequence_enc = self.sequence_embedding(peptide_sequence)
        positional = self._positional_embedding(peptide_sequence.shape[0],
                                                peptide_sequence.shape[1])
        # Place the encoding on the embedding's own device/dtype. Always moving
        # it to the default GPU fails when CUDA exists but the model runs on
        # the CPU or on a different GPU.
        sequence_enc = sequence_enc + positional.to(device=sequence_enc.device,
                                                    dtype=sequence_enc.dtype)
        sequence_enc = self.dropout(sequence_enc)
        prot_enc = self.dropout_2(self.linear(protein_sequence))
        for reciprocal_layer in self.reciprocal_layer_stack:
            # Each layer takes (sequence, protein) and returns the updated
            # encodings first (protein first), then its four attention maps.
            (prot_enc, sequence_enc, prot_attention, sequence_attention,
             prot_seq_attention, seq_prot_attention) = reciprocal_layer(sequence_enc, prot_enc)
            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)
        # Slot 5 carries the prot->seq attentions. Previously it repeated
        # seq_prot_attention_list, so these maps were collected but dropped.
        return (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
                prot_seq_attention_list, seq_prot_attention_list)
class FullModel(nn.Module):
    """Per-position scoring head on top of a frozen ESM-2 encoder.

    Binder and target token batches are embedded by the frozen
    ``facebook/esm2_t33_650M_UR50D`` model and refined jointly by
    ``RepeatedModule2``. A final attention layer then combines the target
    encoding with the binder encoding, and an FFN + linear + sigmoid head
    produces a score in (0, 1) from each target encoding vector.

    Args:
        n_layers: Number of layers in ``RepeatedModule2``.
        d_model, n_head, d_k, d_v, d_inner: Forwarded to ``RepeatedModule2``,
            the final attention layer and the FFN.
        return_attention: If true, ``forward`` also returns the per-layer
            attention lists from the repeated module.
        dropout: Dropout probability used throughout the head.
    """

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()
        # Loads (and may download) pretrained weights via Hugging Face.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # Freeze the encoder so only the layers below receive gradients.
        for param in self.esm_model.parameters():
            param.requires_grad = False
        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)
        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()
        self.return_attention = return_attention

    def forward(self, binder_tokens, target_tokens):
        """Score the target sequence given a binder.

        Args:
            binder_tokens: Mapping unpacked as keyword arguments to the ESM
                model (presumably ``EsmTokenizer`` output for the binder — confirm).
            target_tokens: Same, for the target protein.

        Returns:
            Sigmoid scores from ``output_projection_prot`` applied to the target
            encoding. If ``return_attention`` is true, a 5-tuple ``(scores,
            sequence_attention_list, prot_attention_list,
            prot_seq_attention_list, seq_prot_attention_list)`` instead.
        """
        # The encoder is frozen, so don't build an autograd graph for it.
        with torch.no_grad():
            peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
            protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        # NOTE(review): the token dicts are only consumed by the ESM model; no
        # mask is passed to the layers below, so any padding is not masked there.
        (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
         prot_seq_attention_list, seq_prot_attention_list) = self.repeated_module(
            peptide_sequence, protein_sequence)
        # Presumably (query, key, value): target positions attend over the binder.
        prot_enc, _final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        prot_enc = self.sigmoid(self.output_projection_prot(prot_enc))
        # Honor return_attention the same way Original_FullModel does; it was
        # previously stored but ignored.
        if not self.return_attention:
            return prot_enc
        return (prot_enc, sequence_attention_list, prot_attention_list,
                prot_seq_attention_list, seq_prot_attention_list)
class Original_FullModel(nn.Module):
    """Two-class per-position classifier over peptide tokens and protein features.

    ``RepeatedModule`` jointly encodes the peptide (token indices) and the
    protein (pre-computed features). A final attention layer then combines
    the protein encoding with the peptide encoding, and an FFN + linear
    head outputs log-probabilities over 2 classes from each protein encoding vector.

    Args:
        n_layers: Number of layers in ``RepeatedModule``.
        d_model, n_head, d_k, d_v, d_inner: Forwarded to ``RepeatedModule``,
            the final attention layer and the FFN.
        return_attention: If true, ``forward`` also returns the per-layer
            attention lists from the repeated module.
        dropout: Dropout probability used throughout the head.
    """

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()
        self.repeated_module = RepeatedModule(n_layers, d_model,
                                              n_head, d_k, d_v, d_inner, dropout=dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)
        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 2)
        self.softmax_prot = nn.LogSoftmax(dim=-1)
        self.return_attention = return_attention

    def forward(self, peptide_sequence, protein_sequence):
        """Classify each protein position.

        Args:
            peptide_sequence: Peptide token indices (passed to ``RepeatedModule``).
            protein_sequence: Protein features (passed to ``RepeatedModule``).

        Returns:
            Log-probabilities over 2 classes along the last dimension. If
            ``return_attention`` is true, a 5-tuple ``(log_probs,
            sequence_attention_list, prot_attention_list,
            prot_seq_attention_list, seq_prot_attention_list)`` instead.
        """
        # Unpack into distinct names. Previously slots 5 and 6 were both bound
        # to seq_prot_attention_list, so the prot->seq attentions were lost.
        (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
         prot_seq_attention_list, seq_prot_attention_list) = self.repeated_module(
            peptide_sequence, protein_sequence)
        # Presumably (query, key, value): protein positions attend over the peptide.
        prot_enc, _final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc))
        if not self.return_attention:
            return prot_enc
        return (prot_enc, sequence_attention_list, prot_attention_list,
                prot_seq_attention_list, seq_prot_attention_list)
# Xet Storage Details
# - Size: 9.38 kB
# - Xet hash: 5d2893817106f57ab83597c41d47104632040d11af1a7fb848ab6c95dc49ffea
#
# Xet efficiently stores files, intelligently splitting them into unique
# chunks and accelerating uploads and downloads.