import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import esm

from DeepMFPP.utils import PositionalEncoding, FAN_encode
from DeepMFPP.data_helper import index_alignment
class DeepMFPP(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int, fan_layer_num: int = 1, num_heads: int = 8,
                 encoder_layer_num: int = 1, output_size: int = 21, layer_idx=None, esm_path=None,
                 dropout: float = 0.6, max_pool=5, Contrastive_Learning=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.output_size = output_size
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(self.dropout)
        self.encoder_layer_num = encoder_layer_num
        self.fan_layer_num = fan_layer_num
        self.num_heads = num_heads
        self.max_pool = max_pool
        self.ctl = Contrastive_Learning  # reserved flag; not used in this forward pass
        self.ffn_size = self.embedding_size * 2
        self.dropout_layer1 = nn.Dropout(0.4)
        # Frozen pre-trained ESM encoder, loaded from a local checkpoint.
        self.ESMmodel, _ = esm.pretrained.load_model_and_alphabet_local(esm_path)
        self.ESMmodel.eval()
        self.layer_idx = layer_idx
        self.out_chs = 64
        final_feats_shape = self.out_chs * 50  # assumes a fixed sequence length of 50
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.pos_encoding = PositionalEncoding(num_hiddens=self.embedding_size, dropout=self.dropout)
        # Position-wise feed-forward block applied after the attention fusion.
        self.ffn = nn.Sequential(
            nn.Linear(self.embedding_size, self.ffn_size, bias=True),
            nn.GELU(),
            nn.Linear(self.ffn_size, self.embedding_size, bias=True),
        )
        self.ln1 = nn.LayerNorm(self.embedding_size)
        self.softmax = nn.Softmax(dim=-1)
        self.W_o = nn.Linear(self.embedding_size, self.embedding_size)  # attention output projection
        self.kernel_sizes = [3, 5, 7, 11, 15]
        self.MaxPool1d = nn.MaxPool1d(kernel_size=self.max_pool)
        # One conv branch per kernel size; "same" padding keeps the sequence length.
        self.all_conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(self.embedding_size, out_channels=self.out_chs,
                          kernel_size=k, padding=(k - 1) // 2),
                nn.BatchNorm1d(self.out_chs),
                nn.LeakyReLU(),
            )
            for k in self.kernel_sizes
        ])
        self.fan = FAN_encode(self.dropout, final_feats_shape)
        self.proj_layer = nn.Sequential(
            nn.Linear(final_feats_shape, 1280),
            nn.BatchNorm1d(1280),
            nn.LeakyReLU(),
            nn.Linear(1280, 128),
        )
        self.fc = nn.Sequential(
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Linear(128, self.output_size),
        )
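
    # Overview of the data flow implemented below (a reading of this code, not
    # documentation from the original authors): token ids are encoded twice, by
    # a frozen ESM model and by a learned embedding + positional encoding; the
    # two streams are fused with a single cross-attention block; the fused
    # sequence then passes through multi-scale 1D convolutions, a FAN encoding
    # block, and a projection head that yields a 128-d hidden vector and logits.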
    def CNN1DNet(self, x):
        # Multi-scale 1D convolution: run each branch, pool, then concatenate
        # the pooled feature maps along the (downsampled) sequence dimension.
        all_feats = []
        for conv in self.all_conv:
            conv_x = conv(x)                 # [B, out_chs, S]
            conv_x = self.MaxPool1d(conv_x)  # [B, out_chs, S // max_pool]
            all_feats.append(conv_x)
        return torch.cat(all_feats, dim=-1)  # [B, out_chs, len(kernel_sizes) * (S // max_pool)]
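
    # Shape sketch for CNN1DNet (illustrative only; assumes the fixed S = 50
    # and max_pool = 5 implied by final_feats_shape in __init__):
    #   input            [B, embedding_size, 50]
    #   each conv branch [B, 64, 50]   ("same" padding keeps the length)
    #   MaxPool1d(5)     [B, 64, 10]
    #   concat of 5      [B, 64, 50]   -> flattens to out_chs * 50 = 3200 features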
    def forward(self, x):
        B, S = x.shape
        # --- ESM branch: frozen pre-trained representations ---
        with torch.no_grad():
            results = self.ESMmodel(x, repr_layers=[self.layer_idx], return_contacts=False)
        esm_x = results["representations"][self.layer_idx]  # [B, S, 480/640/1280] depending on the ESM variant
        # --- Feature A: learned embedding + positional encoding ---
        index_ali_x = index_alignment(x, condition_num=1, subtraction_num1=3, subtraction_num2=1)
        embedding_x = self.embedding(index_ali_x)  # [B, S, embedding_size]
        pos_x = self.pos_encoding(embedding_x * math.sqrt(self.embedding_size))  # [B, S, embedding_size]
        feats1 = pos_x  # learned-embedding stream
        feats2 = esm_x  # frozen ESM stream
        # --- Cross-attention feature fusion ---
        d = feats1.size(-1)
        q, k = feats1, feats2
        v = feats1 + feats2
        feats_qk = q @ k.transpose(-1, -2) / math.sqrt(d)  # scaled dot-product scores
        feats_qk = self.softmax(feats_qk)
        feats_v = feats_qk @ v
        # Linear projection back into the output vector space.
        feats_v = self.W_o(feats_v)  # [B, S, H]
        # Pre-LN feed-forward, with the residual taken over v.
        ffn_y = self.ffn(self.ln1(feats_v))
        feats_fuse = v + self.dropout_layer(ffn_y)
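        # In formula form (a paraphrase of the block above):
        #   fused = v + Dropout(FFN(LN(softmax(q k^T / sqrt(d)) v W_o)))
        # i.e. a single pre-LN cross-attention layer whose residual path is v,
        # the elementwise sum of the two feature streams.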
        # --- 1D CNN layer ---
        cnn_input = feats_fuse.permute(0, 2, 1)  # [B, H, S]
        feats3 = self.CNN1DNet(cnn_input)        # [B, out_chs, concatenated pooled length]
        feats_final = self.dropout_layer(feats3)
        # --- FAN layer ---
        fan_input = feats_final.view(B, -1)  # [B, out_chs * 50]
        fan_input = fan_input.unsqueeze(1)   # [B, 1, out_chs * 50]; AddNorm expects normalized_shape=[1, shape]
        for _ in range(self.fan_layer_num):
            fan_input = self.fan(fan_input)  # feed each FAN layer's output into the next
        fan_out = fan_input.squeeze(1)
        # --- Classification head ---
        hidden = self.proj_layer(fan_out)
        logits = self.fc(hidden)
        return hidden, logits
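

# Minimal smoke-test sketch (not part of the original module). The checkpoint
# path, layer index, and sizes below are assumptions chosen to be mutually
# consistent (ESM-2 35M produces 480-d representations at layer 12); substitute
# the values used in your own setup.
if __name__ == "__main__":
    model = DeepMFPP(
        vocab_size=24,                                 # assumed peptide vocabulary size
        embedding_size=480,                            # must match the ESM hidden size
        layer_idx=12,                                  # assumed representation layer
        esm_path="checkpoints/esm2_t12_35M_UR50D.pt",  # hypothetical local checkpoint
    )
    model.eval()
    tokens = torch.randint(4, 24, (2, 50))  # dummy batch: 2 token sequences of length 50
    with torch.no_grad():
        hidden, logits = model(tokens)
    print(hidden.shape, logits.shape)  # expected: torch.Size([2, 128]) torch.Size([2, 21])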