# DeepPD-hf / DeepPD / model.py
# Author: xiaoleon — initial submission (commit 46b9840)
import torch
import torch.nn as nn
import torch.nn.functional as F
from DeepPD.utils import CBAMBlock,Res_Net
from DeepPD.data_helper import Numseq2OneHot
from transformers import BertModel
# Pretrained BERT loaded once, at import time, from a path relative to the current
# working directory. Every MyModel instance stores this same object as `self.model`.
# (The name is presumably a typo for "weight"; it is kept because MyModel references it.)
bert_wight = BertModel.from_pretrained("./DeepPD/BERT")
class MyModel(nn.Module):
    """Two-branch sequence classifier producing 2 logits.

    Branch 1: a 108-d token embedding is concatenated with 20 ``Numseq2OneHot``
    features (128 = the Transformer ``d_model``), then passed through a Transformer
    encoder, 1x1 conv blocks, ``Res_Net`` (applied twice with the same weights) and
    ``CBAMBlock``.
    Branch 2: element ``[0]`` of the module-level BERT model's output.
    Both branches run through the SAME bidirectional GRU (``self.gru``). Each
    branch's GRU outputs and final hidden states are flattened, the two branches
    are concatenated, and the result is classified by ``self.fc``.
    """
    def __init__(self):
        super().__init__()
        # NOTE: despite its name this is used as a channel count below, not the batch size.
        batch_size = 64
        vocab_size = 21
        self.hidden_dim = 25
        self.gru_emb = 128
        self.emb_dim = 108
        # The shared module-level BERT instance (same object for every MyModel).
        self.model = bert_wight
        # Shared by both branches; not batch_first, so forward() permutes to (S, B, H).
        # Input size 128 means the BERT hidden size must also be 128.
        self.gru = nn.GRU(self.gru_emb, self.hidden_dim, num_layers=2,
                          bidirectional=True,dropout=0.1)
        self.embedding = nn.Embedding(vocab_size, self.emb_dim, padding_idx=0)
        # NOTE(review): nn.TransformerEncoder clones the layer it is given, so the parameters
        # registered under `encoder_layer` are not the ones used in forward().
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.resnet = Res_Net(batch_size)
        self.cbamBlock = CBAMBlock(batch_size)
        # 1 -> 64 channels with a 1x1 convolution.
        self.convblock1 = nn.Sequential(
            nn.Conv2d(1,batch_size,1),
            nn.BatchNorm2d(batch_size),
            nn.LeakyReLU()
        )
        # 64 -> 1 channel with a 1x1 convolution.
        self.convblock2 = nn.Sequential(
            nn.Conv2d(batch_size,1,1),
            nn.BatchNorm2d(1),
            nn.LeakyReLU()
        )
        # 4200 = 2 branches x (S*50 GRU outputs + 4*25 hidden-state values) with S == 40.
        self.fc = nn.Sequential( nn.Linear(4200,512),
                                 nn.BatchNorm1d(512),
                                 nn.LeakyReLU(),
                                 nn.Linear(512,32),
                                 nn.BatchNorm1d(32),
                                 nn.LeakyReLU(),
                                 nn.Linear(32,2))
    def forward(self, x):
        """Return (B, 2) logits for integer token ids ``x``.

        ``x`` is assumed to be (B, 40) with ids < 21; the hard-coded ``nn.Linear(4200, ...)``
        only fits a sequence length of 40 (TODO confirm against the caller).
        """
        xx = self.embedding(x) # B, S, 108
        z = Numseq2OneHot(x) # B, S, 20
        z = z.type_as(xx)
        out = torch.cat([xx,z],2) # B, S, 128 (== d_model)
        # NOTE(review): the encoder layer uses the default batch_first=False, so this
        # (B, S, 128) tensor is read as (seq, batch, feature) and attention mixes samples
        # across the batch. Not changed here because fixing it alters trained-model outputs.
        out = self.transformer_encoder(out)
        out = out.unsqueeze(1) # B, 1, S, 128
        out = self.convblock1(out) # B, 64, S, 128
        out = self.resnet(out)
        out = self.resnet(out) # same Res_Net module applied a second time (shared weights)
        out = self.cbamBlock(out)
        out = self.convblock2(out) # B, 1, S, 128 (assuming Res_Net/CBAMBlock keep the shape)
        out = out.squeeze(1)
        out = out.permute(1,0,2) # S, B, 128 -- the GRU is not batch_first
        out,hn = self.gru(out)
        out = out.permute(1,0,2) # B, S, 50 (2 directions x 25)
        hn = hn.permute(1,0,2) # B, 4, 25 (2 layers x 2 directions)
        out = out.reshape(out.shape[0],-1) # B, S*50 = 2000 for S == 40
        hn = hn.reshape(hn.shape[0],-1) # B, 100
        out = torch.cat([out,hn],1) # B, 2100
        # BERT branch: element [0] of the output (last_hidden_state for a transformers
        # BertModel), fed through the same GRU as the first branch.
        out1 = self.model(x)[0] # B, S, 128
        out1 = out1.permute(1,0,2) # S, B, 128
        out1,hn1 = self.gru(out1)
        out1 = out1.permute(1,0,2) # B, S, 50
        hn1= hn1.permute(1,0,2) # B, 4, 25
        out1 = out1.reshape(out1.shape[0],-1) # B, 2000
        hn1 = hn1.reshape(hn1.shape[0],-1) # B, 100
        out1 = torch.cat([out1,hn1],1) # B, 2100
        out = torch.cat([out1,out],1) # B, 4200 (BERT branch first)
        out = self.fc(out)
        return out
from DeepPD.utils_etfc import *
import torch,esm
import torch.nn as nn
from DeepPD.data_helper import index_alignment,seqs2blosum62
import torch.nn.functional as f
# Module-wide compute device: the default CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class DeepPD(nn.Module):
    """Peptide classifier fusing ESM representations, learned embeddings and BLOSUM62 features.

    Three per-residue feature streams are built in ``forward``:

    * feature A: learned embedding -> positional encoding -> ``AttentionEncode`` with a
      residual connection to the raw embedding, summed with the ESM representation of
      layer ``layer_idx``;
    * feature B: ``seqs2blosum62`` features (20 per residue). These are concatenated with
      feature A and passed through parallel 1-D convolutions (kernel sizes 3 and 7), then
      a 2-layer bidirectional GRU;
    * feature C: a linear projection of the ESM representation to ``out_chs`` channels.

    The flattened streams are concatenated and classified either by an MLP head or, when
    ``info_bottleneck`` is set, by a variational information-bottleneck head.

    Args:
        vocab_size: vocabulary size of the learned embedding ``self.embed``.
        embedding_size: width of the learned embedding. It must equal the ESM
            representation width, because the two are summed and ``proj_layer``
            is applied to the ESM output.
        fan_layer_num: stored; not used by this class.
        num_heads: number of attention heads passed to ``AttentionEncode``.
        encoder_layer_num: how many times the (single, shared) ``attention_encode`` is applied.
        seq_len: padded sequence length. The classifier input width is derived from it.
        output_size: number of output logits.
        layer_idx: ESM layer whose representation is used.
        esm_path: local checkpoint path for ``esm.pretrained.load_model_and_alphabet_local``.
        dropout: probability used by ``dropout_layer``. Also passed to ``PositionalEncoding``
            and ``AttentionEncode``.
        max_pool: stored; not used by this class.
        Contrastive_Learning: stored as ``self.ctl``; not used by this class.
        info_bottleneck: use the information-bottleneck head instead of the plain MLP.
    """

    def __init__(self, vocab_size: int, embedding_size: int, fan_layer_num: int, num_heads: int,
                 encoder_layer_num: int = 1, seq_len: int = 40, output_size: int = 2,
                 layer_idx=None, esm_path=None, dropout: float = 0.6, max_pool: int = 4,
                 Contrastive_Learning=False, info_bottleneck=False):
        super(DeepPD, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.output_size = output_size
        self.seq_len = seq_len
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(self.dropout)
        self.encoder_layer_num = encoder_layer_num
        self.fan_layer_num = fan_layer_num
        self.num_heads = num_heads
        self.max_pool = max_pool
        self.ctl = Contrastive_Learning
        self.info_bottleneck = info_bottleneck
        self.ESMmodel, _ = esm.pretrained.load_model_and_alphabet_local(esm_path)
        # The ESM model is queried under torch.no_grad() in forward(), so it gets no gradients.
        self.ESMmodel.eval()
        self.layer_idx = layer_idx
        self.out_chs = 64
        self.kernel_sizes = [3, 7]
        # One conv branch per kernel size. With an odd kernel k, padding (k - 1) // 2
        # keeps the sequence length unchanged.
        self.all_conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(self.embedding_size + 20, out_channels=self.out_chs,
                          kernel_size=k, padding=(k - 1) // 2),
                nn.BatchNorm1d(self.out_chs),
                nn.LeakyReLU(),
            )
            for k in self.kernel_sizes
        ])
        self.hidden_dim = 64
        self.gru = nn.GRU(self.out_chs * 2, self.hidden_dim, num_layers=2, batch_first=True,
                          bidirectional=True, dropout=0.25)
        self.embed = nn.Embedding(self.vocab_size, self.embedding_size)
        self.pos_encoding = PositionalEncoding(num_hiddens=self.embedding_size, dropout=self.dropout)
        self.attention_encode = AttentionEncode(self.dropout, self.embedding_size, self.num_heads,
                                                seq_len=self.seq_len, ffn=False)
        # Width of the flattened feature vector built in forward():
        # seq_len * (bi-GRU output width + feature-C width). It equals 40 * (64*2 + 64)
        # for the defaults, but now follows seq_len instead of assuming 40.
        shape = int(self.seq_len * (self.hidden_dim * 2 + self.out_chs))
        z_dim = 1024
        # Information-bottleneck head: encoder to (mean, std) of a z_dim latent, then a decoder.
        self.enc_mean = nn.Linear(shape, z_dim)
        self.enc_std = nn.Linear(shape, z_dim)
        self.dec = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Linear(128, self.output_size),
        )
        self.proj_layer = nn.Linear(self.embedding_size, self.out_chs)
        # Plain MLP head, used when info_bottleneck is False.
        self.fc = nn.Sequential(
            nn.Linear(shape, z_dim),
            nn.BatchNorm1d(z_dim),
            nn.LeakyReLU(),
            nn.Linear(z_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Linear(128, self.output_size),
        )

    def CNN1DNet(self, x):
        """Apply every conv branch to ``x`` of shape (B, C, S) and concatenate along channels.

        Returns a tensor of shape (B, out_chs * len(kernel_sizes), S).
        """
        return torch.cat([conv(x) for conv in self.all_conv], dim=1)

    def forward(self, x):
        """Classify a batch of tokenised sequences.

        Args:
            x: integer token tensor, assumed (B, seq_len) in the ESM alphabet, since it is
               passed unchanged to the ESM model (TODO confirm against the data pipeline).

        Returns:
            ``(logits, enc_mean, enc_std)`` when ``info_bottleneck`` is set, otherwise
            ``(logits, logits, logits)``. Both heads therefore return a 3-tuple.
        """
        import math  # explicit: `math` was otherwise only available via a star import

        # ESM per-residue representations; no gradients flow into the pretrained model.
        with torch.no_grad():
            results = self.ESMmodel(x, repr_layers=[self.layer_idx], return_contacts=False)
        esm_x = results["representations"][self.layer_idx]  # [B, S, esm_dim]
        # Presumably remaps ESM alphabet indices onto self.embed's vocabulary -- see data_helper.
        x = index_alignment(x, condition_num=1, subtraction_num1=3, subtraction_num2=1)

        # Feature A: learned embedding + positional encoding + self-attention, plus ESM.
        embed_x = self.embed(x)  # [B, S, embedding_size]
        pos_x = self.pos_encoding(embed_x * math.sqrt(self.embedding_size))
        encoding_x = pos_x
        for _ in range(self.encoder_layer_num):
            encoding_x = self.attention_encode(encoding_x)
            # NOTE(review): the source's indentation was lost. The residual is applied per
            # layer here, which is identical to an after-loop residual when
            # encoder_layer_num == 1.
            encoding_x += embed_x
        featA = encoding_x + esm_x

        # Feature B: follow embed_x's device and dtype rather than the global `device`,
        # so a CPU model (or a model on a non-default GPU) works too.
        featB = seqs2blosum62(x).to(device=embed_x.device, dtype=embed_x.dtype)  # [B, S, 20]
        featAB = torch.cat([featA, featB], dim=2)  # [B, S, embedding_size + 20]
        cnn_output = self.CNN1DNet(featAB.permute(0, 2, 1))  # [B, out_chs * 2, S]
        out = self.dropout_layer(cnn_output)
        out = out.permute(0, 2, 1)  # [B, S, out_chs * 2]
        out, _ = self.gru(out)  # [B, S, hidden_dim * 2]
        out = self.dropout_layer(out)
        final_featAB = out.reshape(x.size(0), -1)  # [B, S * hidden_dim * 2]

        # Feature C: ESM representation projected to out_chs per residue.
        featC = self.proj_layer(esm_x)
        featC = self.dropout_layer(featC)
        featC = featC.reshape(featC.shape[0], -1)  # [B, S * out_chs]

        feat = torch.cat([final_featAB, featC], 1)  # [B, S * (hidden_dim * 2 + out_chs)]
        final_feat = self.dropout_layer(feat)

        if self.info_bottleneck:
            # ToxIBTL-style information-bottleneck head: sample z = mean + std * eps with
            # eps ~ N(0, I); the softplus keeps std positive and the -5 shift makes it small.
            enc_mean = self.enc_mean(final_feat)
            enc_std = f.softplus(self.enc_std(final_feat) - 5)
            eps = torch.randn_like(enc_std)
            IB_out = enc_mean + enc_std * eps
            logits = self.dec(IB_out)
            return logits, enc_mean, enc_std
        # Plain fully connected head.
        logits = self.fc(final_feat)
        return logits, logits, logits