Spaces:
Sleeping
Sleeping
Update model_structure.py
Browse files- model_structure.py +18 -82
model_structure.py
CHANGED
|
@@ -1,30 +1,6 @@
|
|
| 1 |
-
import math
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
-
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class AttentionFusion(nn.Module):
|
| 9 |
-
def __init__(self, feature_dims, hidden_dim):
|
| 10 |
-
super(AttentionFusion, self).__init__()
|
| 11 |
-
total_dim = sum(feature_dims)
|
| 12 |
-
self.attention_net = nn.Sequential(
|
| 13 |
-
nn.Linear(total_dim, hidden_dim),
|
| 14 |
-
nn.ReLU(),
|
| 15 |
-
nn.Linear(hidden_dim, len(feature_dims)),
|
| 16 |
-
nn.Softmax(dim=1)
|
| 17 |
-
)
|
| 18 |
-
self.feature_dims = feature_dims
|
| 19 |
-
|
| 20 |
-
def forward(self, feature_list):
|
| 21 |
-
concatenated_features = torch.cat(feature_list, dim=1)
|
| 22 |
-
attention_weights = self.attention_net(concatenated_features)
|
| 23 |
-
fused_feature = 0
|
| 24 |
-
for i, feature in enumerate(feature_list):
|
| 25 |
-
fused_feature += attention_weights[:, i].unsqueeze(1) * feature
|
| 26 |
-
return fused_feature
|
| 27 |
-
|
| 28 |
|
| 29 |
class SelfAttention(nn.Module):
|
| 30 |
def __init__(self, feature_dim):
|
|
@@ -34,33 +10,14 @@ class SelfAttention(nn.Module):
|
|
| 34 |
self.value = nn.Linear(feature_dim, feature_dim)
|
| 35 |
self.softmax = nn.Softmax(dim=-1)
|
| 36 |
|
| 37 |
-
def forward(self, x
|
| 38 |
-
"""
|
| 39 |
-
x: (batch, seq_len, dim)
|
| 40 |
-
mask: (batch, seq_len) with 1 for valid tokens and 0 for padding
|
| 41 |
-
"""
|
| 42 |
q = self.query(x)
|
| 43 |
k = self.key(x)
|
| 44 |
v = self.value(x)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
key_mask = mask.unsqueeze(1).expand(-1, scores.size(1), -1)
|
| 50 |
-
scores = scores.masked_fill(key_mask == 0, -1e9)
|
| 51 |
-
|
| 52 |
-
attn = self.softmax(scores)
|
| 53 |
-
out = torch.bmm(attn, v) # (batch, seq_len, dim)
|
| 54 |
-
|
| 55 |
-
if mask is None:
|
| 56 |
-
return torch.mean(out, dim=1)
|
| 57 |
-
|
| 58 |
-
query_mask = mask.unsqueeze(-1).type_as(out)
|
| 59 |
-
out = out * query_mask
|
| 60 |
-
denom = query_mask.sum(dim=1).clamp(min=1.0)
|
| 61 |
-
pooled = out.sum(dim=1) / denom
|
| 62 |
-
return pooled
|
| 63 |
-
|
| 64 |
|
| 65 |
class ParallelFeatureExtractorWithAttention(nn.Module):
|
| 66 |
def __init__(self, input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate=0.3):
|
|
@@ -80,36 +37,22 @@ class ParallelFeatureExtractorWithAttention(nn.Module):
|
|
| 80 |
self.bilstm_attention = SelfAttention(lstm_hidden_dim * 2)
|
| 81 |
self.bilstm_branch_output_dim = lstm_hidden_dim * 2
|
| 82 |
|
| 83 |
-
def forward(self, sequence_embedding
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
cnn_out = F.relu(
|
| 87 |
-
cnn_out_permuted = cnn_out.permute(0, 2, 1)
|
| 88 |
-
v_cnn = self.cnn_attention(cnn_out_permuted
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
if mask is not None:
|
| 92 |
-
lengths = mask.sum(dim=1).to(torch.long).cpu()
|
| 93 |
-
packed = pack_padded_sequence(sequence_embedding, lengths, batch_first=True, enforce_sorted=False)
|
| 94 |
-
packed_out, _ = self.bilstm(packed)
|
| 95 |
-
lstm_out, _ = pad_packed_sequence(
|
| 96 |
-
packed_out, batch_first=True, total_length=sequence_embedding.size(1)
|
| 97 |
-
)
|
| 98 |
-
else:
|
| 99 |
-
lstm_out, _ = self.bilstm(sequence_embedding)
|
| 100 |
-
|
| 101 |
-
v_bilstm = self.bilstm_attention(lstm_out, mask=mask)
|
| 102 |
return v_cnn, v_bilstm
|
| 103 |
|
| 104 |
-
|
| 105 |
class AVP_Fusion(nn.Module):
|
| 106 |
def __init__(self, esm_dim, additional_dim, cnn_out_channels, lstm_hidden_dim, num_classes, dropout_rate=0.42):
|
| 107 |
super(AVP_Fusion, self).__init__()
|
| 108 |
fused_input_dim = esm_dim + additional_dim
|
| 109 |
-
self.parallel_extractor = ParallelFeatureExtractorWithAttention(
|
| 110 |
-
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
cnn_feature_dim = self.parallel_extractor.cnn_branch_output_dim
|
| 114 |
bilstm_feature_dim = self.parallel_extractor.bilstm_branch_output_dim
|
| 115 |
|
|
@@ -129,21 +72,14 @@ class AVP_Fusion(nn.Module):
|
|
| 129 |
)
|
| 130 |
self.embedding_dim = classifier_input_dim
|
| 131 |
|
| 132 |
-
def forward(self, esm_sequence_embedding, additional_features
|
| 133 |
seq_len = esm_sequence_embedding.size(1)
|
| 134 |
-
|
| 135 |
expanded_additional_features = additional_features.unsqueeze(1).expand(-1, seq_len, -1)
|
| 136 |
fused_sequence_embedding = torch.cat([esm_sequence_embedding, expanded_additional_features], dim=2)
|
| 137 |
-
|
| 138 |
-
v_cnn, v_bilstm = self.parallel_extractor(fused_sequence_embedding
|
| 139 |
-
|
| 140 |
v_cnn_matched = self.cnn_dim_matcher(v_cnn)
|
| 141 |
lambda_gate = self.gating_network(torch.cat([v_cnn, v_bilstm], dim=1))
|
| 142 |
final_embedding = lambda_gate * v_cnn_matched + (1 - lambda_gate) * v_bilstm
|
| 143 |
-
|
| 144 |
logits = self.classifier(final_embedding)
|
| 145 |
return logits, final_embedding
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
# Backward-compatible alias (do not remove)
|
| 149 |
-
AVP_HNCL_v3 = AVP_Fusion
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
class SelfAttention(nn.Module):
|
| 6 |
def __init__(self, feature_dim):
|
|
|
|
| 10 |
self.value = nn.Linear(feature_dim, feature_dim)
|
| 11 |
self.softmax = nn.Softmax(dim=-1)
|
| 12 |
|
| 13 |
+
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
q = self.query(x)
|
| 15 |
k = self.key(x)
|
| 16 |
v = self.value(x)
|
| 17 |
+
attention_scores = torch.bmm(q, k.transpose(1, 2))
|
| 18 |
+
attention_weights = self.softmax(attention_scores / (k.size(-1) ** 0.5))
|
| 19 |
+
weighted_features = torch.bmm(attention_weights, v)
|
| 20 |
+
return torch.mean(weighted_features, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
class ParallelFeatureExtractorWithAttention(nn.Module):
|
| 23 |
def __init__(self, input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate=0.3):
|
|
|
|
| 37 |
self.bilstm_attention = SelfAttention(lstm_hidden_dim * 2)
|
| 38 |
self.bilstm_branch_output_dim = lstm_hidden_dim * 2
|
| 39 |
|
| 40 |
+
def forward(self, sequence_embedding):
|
| 41 |
+
cnn_in = sequence_embedding.permute(0, 2, 1)
|
| 42 |
+
cnn_out = self.cnn(cnn_in)
|
| 43 |
+
cnn_out = F.relu(cnn_out)
|
| 44 |
+
cnn_out_permuted = cnn_out.permute(0, 2, 1)
|
| 45 |
+
v_cnn = self.cnn_attention(cnn_out_permuted)
|
| 46 |
+
lstm_out, _ = self.bilstm(sequence_embedding)
|
| 47 |
+
v_bilstm = self.bilstm_attention(lstm_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
return v_cnn, v_bilstm
|
| 49 |
|
|
|
|
| 50 |
class AVP_Fusion(nn.Module):
|
| 51 |
def __init__(self, esm_dim, additional_dim, cnn_out_channels, lstm_hidden_dim, num_classes, dropout_rate=0.42):
|
| 52 |
super(AVP_Fusion, self).__init__()
|
| 53 |
fused_input_dim = esm_dim + additional_dim
|
| 54 |
+
self.parallel_extractor = ParallelFeatureExtractorWithAttention(fused_input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate)
|
| 55 |
+
|
|
|
|
|
|
|
| 56 |
cnn_feature_dim = self.parallel_extractor.cnn_branch_output_dim
|
| 57 |
bilstm_feature_dim = self.parallel_extractor.bilstm_branch_output_dim
|
| 58 |
|
|
|
|
| 72 |
)
|
| 73 |
self.embedding_dim = classifier_input_dim
|
| 74 |
|
| 75 |
+
def forward(self, esm_sequence_embedding, additional_features):
|
| 76 |
seq_len = esm_sequence_embedding.size(1)
|
|
|
|
| 77 |
expanded_additional_features = additional_features.unsqueeze(1).expand(-1, seq_len, -1)
|
| 78 |
fused_sequence_embedding = torch.cat([esm_sequence_embedding, expanded_additional_features], dim=2)
|
| 79 |
+
|
| 80 |
+
v_cnn, v_bilstm = self.parallel_extractor(fused_sequence_embedding)
|
|
|
|
| 81 |
v_cnn_matched = self.cnn_dim_matcher(v_cnn)
|
| 82 |
lambda_gate = self.gating_network(torch.cat([v_cnn, v_bilstm], dim=1))
|
| 83 |
final_embedding = lambda_gate * v_cnn_matched + (1 - lambda_gate) * v_bilstm
|
|
|
|
| 84 |
logits = self.classifier(final_embedding)
|
| 85 |
return logits, final_embedding
|
|
|
|
|
|
|
|
|
|
|
|