import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import esm

from DeepMFPP.utils import PositionalEncoding, FAN_encode
from DeepMFPP.data_helper import index_alignment


class DeepMFPP(nn.Module):
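    """Multi-functional peptide prediction model.

    Fuses two token-level feature streams: a learned embedding with
    positional encoding (feature A) and representations from a frozen
    ESM-2 backbone. The fused features pass through a multi-kernel 1D CNN,
    FAN encoding, and a projection head, yielding a 128-d hidden vector
    (also returned for contrastive learning) and the class logits.
    """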
    def __init__(self, vocab_size: int, embedding_size: int, fan_layer_num: int = 1, num_heads: int = 8,
                 encoder_layer_num: int = 1, output_size: int = 21, layer_idx=None, esm_path=None,
                 dropout: float = 0.6, max_pool=5, Contrastive_Learning=False):
        super().__init__()
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.output_size = output_size
self.dropout = dropout
self.dropout_layer = nn.Dropout(self.dropout)
self.encoder_layer_num = encoder_layer_num
self.fan_layer_num = fan_layer_num
self.num_heads = num_heads
self.max_pool = max_pool
self.ctl = Contrastive_Learning
self.ffn_size = self.embedding_size*2
self.dropout_layer1 = nn.Dropout(0.4)
self.ESMmodel,_ = esm.pretrained.load_model_and_alphabet_local(esm_path)
self.ESMmodel.eval()
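        # the ESM backbone stays in eval mode and forward() wraps it in
        # torch.no_grad(), so it acts as a frozen feature extractor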
self.layer_idx = layer_idx
        self.out_chs = 64
        # 5 conv branches, each max-pooled to 50 // 5 = 10 positions: 5 * 10 = 50
        final_feats_shape = self.out_chs * 50  # assumes a fixed input length of 50
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
self.pos_encoding = PositionalEncoding(num_hiddens=self.embedding_size,dropout=self.dropout)
# self.attention_encode = AttentionEncode(self.dropout, self.embedding_size, self.num_heads,ffn=False)
self.ffn = nn.Sequential(
nn.Linear(self.embedding_size, self.embedding_size*2, bias=True),
nn.GELU(),
# nn.LeakyReLU(),
nn.Linear(self.embedding_size*2, self.embedding_size, bias=True),
)
self.ln1 = nn.LayerNorm(self.embedding_size)
self.softmax = nn.Softmax(dim=-1)
self.W_o = nn.Linear(self.embedding_size,self.embedding_size)
self.kernel_sizes = [3,5,7,11,15]
self.MaxPool1d = nn.MaxPool1d(kernel_size=self.max_pool)
self.all_conv = nn.ModuleList([
nn.Sequential(
nn.Conv1d(self.embedding_size,out_channels=self.out_chs,kernel_size=self.kernel_sizes[i],padding=(self.kernel_sizes[i]-1)//2),
nn.BatchNorm1d(self.out_chs),
nn.LeakyReLU()
)
for i in range(len(self.kernel_sizes))
])
# self.project_layer =nn.Linear(self.embedding_size,64)
self.fan = FAN_encode(self.dropout, final_feats_shape)
        self.proj_layer = nn.Sequential(
            nn.Linear(final_feats_shape, 1280),
            nn.BatchNorm1d(1280),
            nn.LeakyReLU(),
            nn.Linear(1280, 128),
        )
self.fc = nn.Sequential(
nn.BatchNorm1d(128),
nn.LeakyReLU(),
nn.Linear(128,self.output_size)
)
def CNN1DNet(self,x):
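        """Multi-scale 1D convolution over the fused features.

        x: [B, embedding_size, S]. Each branch applies a conv with "same"
        padding ([B, out_chs, S]) followed by max-pooling
        ([B, out_chs, S // max_pool]); the five branch outputs are
        concatenated along the length dimension.
        """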
        branch_outputs = []
        for conv in self.all_conv:
            conv_x = conv(x)                 # [B, out_chs, S] ("same" padding)
            conv_x = self.MaxPool1d(conv_x)  # [B, out_chs, S // max_pool]
            branch_outputs.append(conv_x)
        # concatenate the pooled branches along the length dimension
        return torch.cat(branch_outputs, dim=-1)
def forward(self, x):
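        """Forward pass.

        x: [B, S] ESM-alphabet token ids; S is expected to be 50 so that the
        flattened CNN features match final_feats_shape = out_chs * 50.
        Returns (hidden, logits): [B, 128] and [B, output_size].
        """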
B,S = x.shape
H = self.embedding_size
# --- ESM layer ----
with torch.no_grad():
results = self.ESMmodel(x, repr_layers=[self.layer_idx], return_contacts=False)
            esm_x = results["representations"][self.layer_idx]  # [B, S, D]; D = 480/640/1280 depending on the ESM-2 variant
        # --- feature A: token embedding + positional encoding ---
index_ali_x = index_alignment(x,condition_num=1,subtraction_num1=3,subtraction_num2=1)
embedding_x = self.embedding(index_ali_x) # [batch_size,seq_len,embedding_size]
pos_x = self.pos_encoding(embedding_x * math.sqrt(self.embedding_size)) # [batch_size,seq_len,embedding_size]
feats1 = pos_x
# feats1 = embedding_x
# feats_fuse = feats1
# for _ in range(self.encoder_layer_num):
# feats1 = self.attention_encode(feats1)
# feats1 += embedding_x # B,S,H
# feats1 += esm_x
feats2 = esm_x
# feats_fuse = feats2
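        # Fuse the two streams with single-head scaled dot-product attention:
        # queries come from the embedding stream, keys from the ESM stream,
        # and values from their sum, so each position re-weights the combined
        # features by embedding/ESM similarity.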
        # --- cross-attention feature fusion ---
        d = feats1.size(-1)
        q, k = feats1, feats2
        v = feats1 + feats2
        # scaled dot-product attention: scores are scaled by 1 / sqrt(d)
        feats_qk = q @ k.transpose(-1, -2) / math.sqrt(d)
        feats_qk = self.softmax(feats_qk)
        feats_v = feats_qk @ v
        # linear projection into the output vector space
        feats_v = self.W_o(feats_v)  # [B,S,H]
        # pre-norm feed-forward block with a residual connection around v
        ffn_y = self.ffn(self.ln1(feats_v))
        feats_fuse = v + self.dropout_layer(ffn_y)
# feats_fuse = feats1 + feats2
# feats_final = self.dropout_layer(self.project_layer(feats_fuse))
        # --- 1D CNN layer ---
        cnn_input = feats_fuse.permute(0, 2, 1)  # [B,H,S]
        feats3 = self.CNN1DNet(cnn_input)        # [B, out_chs, 5 * (S // max_pool)]
        feats3 = self.dropout_layer(feats3)
        feats_final = feats3
        # --- FAN layer ---
        fan_input = feats_final.view(x.size(0), -1)  # [B, seq_len * feat_dim] = [B, 50 * 64]
        # unsqueeze so AddNorm inside FAN_encode sees normalized_shape = [1, final_feats_shape]
        fan_input = fan_input.unsqueeze(1)  # [B, 1, 50 * 64]
        for _ in range(self.fan_layer_num):
            fan_input = self.fan(fan_input)  # chain FAN blocks: each one encodes the previous output
        fan_out = fan_input.squeeze(1)
# --- CLSFC layer ---
hidden = self.proj_layer(fan_out)
logits = self.fc(hidden)
# return feats1,feats2,feats_fuse,feats_final,fan_out,hidden,logits
        return hidden, logits
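

# A minimal smoke-test sketch, not part of the original module. It assumes a
# local ESM-2 t12 checkpoint (whose 480-d representations match embedding_size)
# and a vocab_size large enough for the indices produced by index_alignment;
# the checkpoint path and hyperparameters below are illustrative assumptions.
if __name__ == "__main__":
    esm_path = "esm2_t12_35M_UR50D.pt"  # hypothetical local checkpoint path
    _, alphabet = esm.pretrained.load_model_and_alphabet_local(esm_path)
    batch_converter = alphabet.get_batch_converter()
    _, _, tokens = batch_converter([
        ("pep1", "ACDEFGHIKLMNPQRSTVWY"),
        ("pep2", "GAVLIPFMWSTCYNQDEKRH"),
    ])
    # pad to the fixed length of 50 expected by the CNN/FAN stages
    tokens = F.pad(tokens, (0, 50 - tokens.size(1)), value=alphabet.padding_idx)
    model = DeepMFPP(vocab_size=33, embedding_size=480, layer_idx=12, esm_path=esm_path)
    model.eval()
    with torch.no_grad():
        hidden, logits = model(tokens)
    print(hidden.shape, logits.shape)  # expected: torch.Size([2, 128]), torch.Size([2, 21])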