Wwwy1031 committed on
Commit e5c3ed3 · verified · 1 Parent(s): 59588c5

Update model_structure.py

Files changed (1):
  model_structure.py (+18 -82)
model_structure.py CHANGED
@@ -1,30 +1,6 @@
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
-
-
-class AttentionFusion(nn.Module):
-    def __init__(self, feature_dims, hidden_dim):
-        super(AttentionFusion, self).__init__()
-        total_dim = sum(feature_dims)
-        self.attention_net = nn.Sequential(
-            nn.Linear(total_dim, hidden_dim),
-            nn.ReLU(),
-            nn.Linear(hidden_dim, len(feature_dims)),
-            nn.Softmax(dim=1)
-        )
-        self.feature_dims = feature_dims
-
-    def forward(self, feature_list):
-        concatenated_features = torch.cat(feature_list, dim=1)
-        attention_weights = self.attention_net(concatenated_features)
-        fused_feature = 0
-        for i, feature in enumerate(feature_list):
-            fused_feature += attention_weights[:, i].unsqueeze(1) * feature
-        return fused_feature
-

 class SelfAttention(nn.Module):
     def __init__(self, feature_dim):
@@ -34,33 +10,14 @@ class SelfAttention(nn.Module):
         self.value = nn.Linear(feature_dim, feature_dim)
         self.softmax = nn.Softmax(dim=-1)

-    def forward(self, x, mask=None):
-        """
-        x: (batch, seq_len, dim)
-        mask: (batch, seq_len) with 1 for valid tokens and 0 for padding
-        """
+    def forward(self, x):
         q = self.query(x)
         k = self.key(x)
         v = self.value(x)
-
-        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(k.size(-1))
-
-        if mask is not None:
-            key_mask = mask.unsqueeze(1).expand(-1, scores.size(1), -1)
-            scores = scores.masked_fill(key_mask == 0, -1e9)
-
-        attn = self.softmax(scores)
-        out = torch.bmm(attn, v)  # (batch, seq_len, dim)
-
-        if mask is None:
-            return torch.mean(out, dim=1)
-
-        query_mask = mask.unsqueeze(-1).type_as(out)
-        out = out * query_mask
-        denom = query_mask.sum(dim=1).clamp(min=1.0)
-        pooled = out.sum(dim=1) / denom
-        return pooled
-
+        attention_scores = torch.bmm(q, k.transpose(1, 2))
+        attention_weights = self.softmax(attention_scores / (k.size(-1) ** 0.5))
+        weighted_features = torch.bmm(attention_weights, v)
+        return torch.mean(weighted_features, dim=1)

 class ParallelFeatureExtractorWithAttention(nn.Module):
     def __init__(self, input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate=0.3):
@@ -80,36 +37,22 @@ class ParallelFeatureExtractorWithAttention(nn.Module):
         self.bilstm_attention = SelfAttention(lstm_hidden_dim * 2)
         self.bilstm_branch_output_dim = lstm_hidden_dim * 2

-    def forward(self, sequence_embedding, mask=None):
-        # CNN branch
-        cnn_in = sequence_embedding.permute(0, 2, 1)  # (batch, dim, seq_len)
-        cnn_out = F.relu(self.cnn(cnn_in))
-        cnn_out_permuted = cnn_out.permute(0, 2, 1)  # (batch, seq_len, channels)
-        v_cnn = self.cnn_attention(cnn_out_permuted, mask=mask)
-
-        # BiLSTM branch (packed to ignore padding)
-        if mask is not None:
-            lengths = mask.sum(dim=1).to(torch.long).cpu()
-            packed = pack_padded_sequence(sequence_embedding, lengths, batch_first=True, enforce_sorted=False)
-            packed_out, _ = self.bilstm(packed)
-            lstm_out, _ = pad_packed_sequence(
-                packed_out, batch_first=True, total_length=sequence_embedding.size(1)
-            )
-        else:
-            lstm_out, _ = self.bilstm(sequence_embedding)
-
-        v_bilstm = self.bilstm_attention(lstm_out, mask=mask)
+    def forward(self, sequence_embedding):
+        cnn_in = sequence_embedding.permute(0, 2, 1)
+        cnn_out = self.cnn(cnn_in)
+        cnn_out = F.relu(cnn_out)
+        cnn_out_permuted = cnn_out.permute(0, 2, 1)
+        v_cnn = self.cnn_attention(cnn_out_permuted)
+        lstm_out, _ = self.bilstm(sequence_embedding)
+        v_bilstm = self.bilstm_attention(lstm_out)
         return v_cnn, v_bilstm

-
 class AVP_Fusion(nn.Module):
     def __init__(self, esm_dim, additional_dim, cnn_out_channels, lstm_hidden_dim, num_classes, dropout_rate=0.42):
         super(AVP_Fusion, self).__init__()
         fused_input_dim = esm_dim + additional_dim
-        self.parallel_extractor = ParallelFeatureExtractorWithAttention(
-            fused_input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate
-        )
-
+        self.parallel_extractor = ParallelFeatureExtractorWithAttention(fused_input_dim, cnn_out_channels, lstm_hidden_dim, dropout_rate)
+
         cnn_feature_dim = self.parallel_extractor.cnn_branch_output_dim
         bilstm_feature_dim = self.parallel_extractor.bilstm_branch_output_dim

@@ -129,21 +72,14 @@ class AVP_Fusion(nn.Module):
         )
         self.embedding_dim = classifier_input_dim

-    def forward(self, esm_sequence_embedding, additional_features, attention_mask=None):
+    def forward(self, esm_sequence_embedding, additional_features):
         seq_len = esm_sequence_embedding.size(1)
-
         expanded_additional_features = additional_features.unsqueeze(1).expand(-1, seq_len, -1)
         fused_sequence_embedding = torch.cat([esm_sequence_embedding, expanded_additional_features], dim=2)
-
-        v_cnn, v_bilstm = self.parallel_extractor(fused_sequence_embedding, mask=attention_mask)
-
+
+        v_cnn, v_bilstm = self.parallel_extractor(fused_sequence_embedding)
         v_cnn_matched = self.cnn_dim_matcher(v_cnn)
         lambda_gate = self.gating_network(torch.cat([v_cnn, v_bilstm], dim=1))
         final_embedding = lambda_gate * v_cnn_matched + (1 - lambda_gate) * v_bilstm
-
         logits = self.classifier(final_embedding)
         return logits, final_embedding
-
-
-# Backward-compatible alias (do not remove)
-AVP_HNCL_v3 = AVP_Fusion
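For reference, a minimal sketch of calling the model after this commit. Every dimension below (ESM embedding width, feature width, channel and hidden sizes, the two-class setup) is an illustrative assumption, not a value taken from this repo; only the constructor and forward signatures come from the diff above.

import torch

from model_structure import AVP_Fusion

# Illustrative hyperparameters; these are assumptions, not repo values.
model = AVP_Fusion(
    esm_dim=1280,          # assumed per-residue ESM embedding width
    additional_dim=20,     # assumed handcrafted-feature width
    cnn_out_channels=128,  # assumed
    lstm_hidden_dim=128,   # assumed
    num_classes=2,         # assumed binary AVP / non-AVP setup
)
model.eval()

batch, seq_len = 4, 50
esm_sequence_embedding = torch.randn(batch, seq_len, 1280)  # (batch, seq_len, esm_dim)
additional_features = torch.randn(batch, 20)                # (batch, additional_dim)

with torch.no_grad():
    # attention_mask is no longer accepted after this commit, so any
    # padded positions now participate in attention and mean pooling.
    logits, embedding = model(esm_sequence_embedding, additional_features)

print(logits.shape)     # expected: (batch, num_classes)
print(embedding.shape)  # expected: (batch, model.embedding_dim)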