# AVP-Pro / model_structure.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Scaled dot-product self-attention that pools a variable-length sequence into one vector."""

    def __init__(self, feature_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (batch, seq_len, feature_dim)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Attention scores: (batch, seq_len, seq_len), scaled by sqrt(feature_dim)
        attention_scores = torch.bmm(q, k.transpose(1, 2))
        attention_weights = self.softmax(attention_scores / (k.size(-1) ** 0.5))
        weighted_features = torch.bmm(attention_weights, v)
        # Mean-pool over the sequence dimension -> (batch, feature_dim)
        return torch.mean(weighted_features, dim=1)


class ParallelFeatureExtractorWithAttention(nn.Module):
    """Runs a CNN branch and a BiLSTM branch over the same sequence in parallel,
    pooling each branch to a fixed-size vector with self-attention."""

    def __init__(self, input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate=0.3):
        super(ParallelFeatureExtractorWithAttention, self).__init__()
        # CNN branch: 1D convolution along the sequence dimension
        self.cnn = nn.Conv1d(in_channels=input_dim, out_channels=cnn_out_channels, kernel_size=3, padding=1)
        self.cnn_attention = SelfAttention(cnn_out_channels)
        self.cnn_branch_output_dim = cnn_out_channels
        # BiLSTM branch: 2-layer bidirectional LSTM
        self.bilstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate
        )
        self.bilstm_attention = SelfAttention(lstm_hidden_dim * 2)
        self.bilstm_branch_output_dim = lstm_hidden_dim * 2

    def forward(self, sequence_embedding):
        # sequence_embedding: (batch, seq_len, input_dim); Conv1d expects (batch, channels, seq_len)
        cnn_in = sequence_embedding.permute(0, 2, 1)
        cnn_out = self.cnn(cnn_in)
        cnn_out = F.relu(cnn_out)
        # Back to (batch, seq_len, cnn_out_channels) for attention pooling
        cnn_out_permuted = cnn_out.permute(0, 2, 1)
        v_cnn = self.cnn_attention(cnn_out_permuted)
        # BiLSTM branch consumes the sequence directly
        lstm_out, _ = self.bilstm(sequence_embedding)
        v_bilstm = self.bilstm_attention(lstm_out)
        return v_cnn, v_bilstm


class AVP_Fusion(nn.Module):
    """Fuses per-residue ESM embeddings with per-sequence additional features, extracts
    CNN and BiLSTM representations in parallel, and mixes them with a learned gate."""

    def __init__(self, esm_dim, additional_dim, cnn_out_channels, lstm_hidden_dim, num_classes, dropout_rate=0.42):
        super(AVP_Fusion, self).__init__()
        # Per-residue input = ESM embedding concatenated with the broadcast additional features
        fused_input_dim = esm_dim + additional_dim
        self.parallel_extractor = ParallelFeatureExtractorWithAttention(fused_input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate)
        cnn_feature_dim = self.parallel_extractor.cnn_branch_output_dim
        bilstm_feature_dim = self.parallel_extractor.bilstm_branch_output_dim
        # Gating network outputs a scalar weight in (0, 1) per sample
        self.gating_network = nn.Sequential(
            nn.Linear(cnn_feature_dim + bilstm_feature_dim, 1),
            nn.Sigmoid()
        )
        # Project the CNN vector to the BiLSTM dimension so the two branches can be mixed
        self.cnn_dim_matcher = nn.Linear(cnn_feature_dim, bilstm_feature_dim)
        classifier_input_dim = bilstm_feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, classifier_input_dim // 2),
            nn.BatchNorm1d(classifier_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(classifier_input_dim // 2, num_classes)
        )
        self.embedding_dim = classifier_input_dim

    def forward(self, esm_sequence_embedding, additional_features):
        # esm_sequence_embedding: (batch, seq_len, esm_dim)
        # additional_features: (batch, additional_dim), repeated along the sequence dimension
        seq_len = esm_sequence_embedding.size(1)
        expanded_additional_features = additional_features.unsqueeze(1).expand(-1, seq_len, -1)
        fused_sequence_embedding = torch.cat([esm_sequence_embedding, expanded_additional_features], dim=2)
        v_cnn, v_bilstm = self.parallel_extractor(fused_sequence_embedding)
        v_cnn_matched = self.cnn_dim_matcher(v_cnn)
        # Gate decides how much of the CNN branch versus the BiLSTM branch to keep
        lambda_gate = self.gating_network(torch.cat([v_cnn, v_bilstm], dim=1))
        final_embedding = lambda_gate * v_cnn_matched + (1 - lambda_gate) * v_bilstm
        logits = self.classifier(final_embedding)
        return logits, final_embedding
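

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the published model definition. The
    # dimensions below (ESM embedding size 1280, 10 additional features, batch of 4,
    # sequence length 50) are illustrative assumptions, not values prescribed by AVP-Pro.
    model = AVP_Fusion(
        esm_dim=1280,
        additional_dim=10,
        cnn_out_channels=128,
        lstm_hidden_dim=64,
        num_classes=2,
    )
    model.eval()  # use BatchNorm running statistics for a deterministic forward pass
    esm_embeddings = torch.randn(4, 50, 1280)  # (batch, seq_len, esm_dim)
    extra_features = torch.randn(4, 10)        # (batch, additional_dim)
    with torch.no_grad():
        logits, embedding = model(esm_embeddings, extra_features)
    print(logits.shape)     # expected: torch.Size([4, 2])
    print(embedding.shape)  # expected: torch.Size([4, 128]), i.e. 2 * lstm_hidden_dim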