import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Scaled dot-product self-attention that pools a sequence into a single vector."""

    def __init__(self, feature_dim):
        super().__init__()
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (batch, seq_len, feature_dim)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scaled dot-product attention over the sequence dimension.
        attention_scores = torch.bmm(q, k.transpose(1, 2))
        attention_weights = self.softmax(attention_scores / (k.size(-1) ** 0.5))
        weighted_features = torch.bmm(attention_weights, v)
        # Mean-pool the attended sequence into a fixed-size representation.
        return torch.mean(weighted_features, dim=1)


class ParallelFeatureExtractorWithAttention(nn.Module):
    """Parallel CNN and BiLSTM branches, each followed by attention pooling."""

    def __init__(self, input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate=0.3):
        super().__init__()
        # CNN branch: 1-D convolution over the sequence, then attention pooling.
        self.cnn = nn.Conv1d(
            in_channels=input_dim,
            out_channels=cnn_out_channels,
            kernel_size=3,
            padding=1,
        )
        self.cnn_attention = SelfAttention(cnn_out_channels)
        self.cnn_branch_output_dim = cnn_out_channels

        # BiLSTM branch: 2-layer bidirectional LSTM, then attention pooling.
        self.bilstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate,
        )
        self.bilstm_attention = SelfAttention(lstm_hidden_dim * 2)
        self.bilstm_branch_output_dim = lstm_hidden_dim * 2

    def forward(self, sequence_embedding):
        # sequence_embedding: (batch, seq_len, input_dim)
        # Conv1d expects (batch, channels, seq_len), so permute in and out.
        cnn_in = sequence_embedding.permute(0, 2, 1)
        cnn_out = F.relu(self.cnn(cnn_in))
        cnn_out_permuted = cnn_out.permute(0, 2, 1)
        v_cnn = self.cnn_attention(cnn_out_permuted)

        lstm_out, _ = self.bilstm(sequence_embedding)
        v_bilstm = self.bilstm_attention(lstm_out)
        return v_cnn, v_bilstm


class AVP_Fusion(nn.Module):
    """Fuses ESM sequence embeddings with additional per-sequence features,
    extracts CNN and BiLSTM representations in parallel, and combines them
    with a learned sigmoid gate before classification."""

    def __init__(self, esm_dim, additional_dim, cnn_out_channels, lstm_hidden_dim,
                 num_classes, dropout_rate=0.42):
        super().__init__()
        fused_input_dim = esm_dim + additional_dim
        self.parallel_extractor = ParallelFeatureExtractorWithAttention(
            fused_input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate
        )
        cnn_feature_dim = self.parallel_extractor.cnn_branch_output_dim
        bilstm_feature_dim = self.parallel_extractor.bilstm_branch_output_dim

        # Scalar gate lambda in (0, 1) that weights the contribution of each branch.
        self.gating_network = nn.Sequential(
            nn.Linear(cnn_feature_dim + bilstm_feature_dim, 1),
            nn.Sigmoid(),
        )
        # Project the CNN branch to the BiLSTM dimension so the gated sum is well defined.
        self.cnn_dim_matcher = nn.Linear(cnn_feature_dim, bilstm_feature_dim)

        classifier_input_dim = bilstm_feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, classifier_input_dim // 2),
            nn.BatchNorm1d(classifier_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(classifier_input_dim // 2, num_classes),
        )
        self.embedding_dim = classifier_input_dim

    def forward(self, esm_sequence_embedding, additional_features):
        # esm_sequence_embedding: (batch, seq_len, esm_dim)
        # additional_features:    (batch, additional_dim), broadcast across all positions.
        seq_len = esm_sequence_embedding.size(1)
        expanded_additional_features = additional_features.unsqueeze(1).expand(-1, seq_len, -1)
        fused_sequence_embedding = torch.cat(
            [esm_sequence_embedding, expanded_additional_features], dim=2
        )

        v_cnn, v_bilstm = self.parallel_extractor(fused_sequence_embedding)
        v_cnn_matched = self.cnn_dim_matcher(v_cnn)

        # Gated fusion: final = lambda * CNN branch + (1 - lambda) * BiLSTM branch.
        lambda_gate = self.gating_network(torch.cat([v_cnn, v_bilstm], dim=1))
        final_embedding = lambda_gate * v_cnn_matched + (1 - lambda_gate) * v_bilstm

        logits = self.classifier(final_embedding)
        return logits, final_embedding
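

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): the dimensions below (esm_dim=1280, additional_dim=20, batch of 4,
# length-50 sequences, 2 classes) are placeholder values chosen only to show the
# expected tensor shapes flowing through AVP_Fusion.
if __name__ == "__main__":
    model = AVP_Fusion(
        esm_dim=1280,
        additional_dim=20,
        cnn_out_channels=128,
        lstm_hidden_dim=64,
        num_classes=2,
    )
    esm_sequence_embedding = torch.randn(4, 50, 1280)  # (batch, seq_len, esm_dim)
    additional_features = torch.randn(4, 20)           # (batch, additional_dim)
    logits, final_embedding = model(esm_sequence_embedding, additional_features)
    print(logits.shape)           # torch.Size([4, 2])
    print(final_embedding.shape)  # torch.Size([4, 128]), i.e. lstm_hidden_dim * 2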