SkywalkerLu committed
Commit 276662d · verified · 1 Parent(s): 3026373

Update modeling_transhla2.py

Files changed (1)
  1. modeling_transhla2.py +146 -67
modeling_transhla2.py CHANGED
@@ -1,84 +1,150 @@
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
-
 from peft import LoraConfig, get_peft_model, TaskType
 from transformers import EsmModel

+
 class TransHLA2Config(PretrainedConfig):
     model_type = "transhla2"
-    def __init__(self, d_model=480, **kwargs):
+
+    def __init__(
+        self,
+        d_model=480,
+        n_layers=4,
+        n_head=8,
+        d_ff=64,
+        cnn_num_channel=256,
+        region_embedding_size=3,
+        cnn_kernel_size=3,
+        cnn_padding_size=1,
+        cnn_stride=1,
+        pooling_size=2,
+        esm_model_name="facebook/esm2_t12_35M_UR50D",
+        lora_r=8,
+        lora_alpha=32,
+        lora_dropout=0.1,
+        lora_inference_mode=False,
+        target_modules=None,
+        return_prob=True,  # if True, forward returns softmax probabilities; otherwise logits
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.d_model = d_model
-        # additional custom parameters can be added here
+        self.n_layers = n_layers
+        self.n_head = n_head
+        self.d_ff = d_ff
+        self.cnn_num_channel = cnn_num_channel
+        self.region_embedding_size = region_embedding_size
+        self.cnn_kernel_size = cnn_kernel_size
+        self.cnn_padding_size = cnn_padding_size
+        self.cnn_stride = cnn_stride
+        self.pooling_size = pooling_size
+
+        self.esm_model_name = esm_model_name
+
+        self.lora_r = lora_r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = lora_dropout
+        self.lora_inference_mode = lora_inference_mode
+        self.target_modules = target_modules or ['query', 'out_proj', 'value', 'key', 'dense', 'regression']
+
+        self.return_prob = return_prob
+
+
 class TransHLA2(PreTrainedModel):
     config_class = TransHLA2Config
-    def __init__(self):
+
+    def __init__(self, config: TransHLA2Config):
         super().__init__(config)
+        self.config = config

-        n_layers = 4
-        n_head = 8
         d_model = config.d_model
-        d_ff = 64
-        cnn_num_channel = 256
-        region_embedding_size = 3
-        cnn_kernel_size = 3
-        cnn_padding_size = 1
-        cnn_stride = 1
-        pooling_size = 2
-        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
-        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
+        n_layers = config.n_layers
+        n_head = config.n_head
+        d_ff = config.d_ff
+        cnn_num_channel = config.cnn_num_channel
+        region_embedding_size = config.region_embedding_size
+        cnn_kernel_size = config.cnn_kernel_size
+        cnn_padding_size = config.cnn_padding_size
+        cnn_stride = config.cnn_stride
+        pooling_size = config.pooling_size
+
+        # Backbone + LoRA
+        self.esm = EsmModel.from_pretrained(config.esm_model_name)
         self.peft_config = LoraConfig(
-            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'], task_type=TaskType.FEATURE_EXTRACTION, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
-        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
+            target_modules=config.target_modules,
+            task_type=TaskType.FEATURE_EXTRACTION,
+            inference_mode=config.lora_inference_mode,
+            r=config.lora_r,
+            lora_alpha=config.lora_alpha,
+            lora_dropout=config.lora_dropout,
+        )
+        # Two LoRA adapters, one for the epitope branch and one for the HLA branch
         self.epitope_lora = get_peft_model(self.esm, self.peft_config)
         self.hla_lora = get_peft_model(self.esm, self.peft_config)

-        self.region_cnn1 = nn.Conv1d(
-            d_model, cnn_num_channel, region_embedding_size)
-        self.region_cnn2 = nn.Conv1d(
-            d_model, cnn_num_channel, region_embedding_size)
+        # CNN branches
+        self.region_cnn1 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
+        self.region_cnn2 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
         self.padding1 = nn.ConstantPad1d((1, 1), 0)
         self.padding2 = nn.ConstantPad1d((0, 1), 0)
         self.relu = nn.SiLU()
-        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
-                              padding=cnn_padding_size, stride=cnn_stride)
-        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
-                              padding=cnn_padding_size, stride=cnn_stride)
+        self.cnn1 = nn.Conv1d(
+            cnn_num_channel, cnn_num_channel,
+            kernel_size=cnn_kernel_size, padding=cnn_padding_size, stride=cnn_stride
+        )
+        self.cnn2 = nn.Conv1d(
+            cnn_num_channel, cnn_num_channel,
+            kernel_size=cnn_kernel_size, padding=cnn_padding_size, stride=cnn_stride
+        )
         self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
+
+        # Transformer encoders (expect shape [S, B, D])
         self.epitope_transformer_layers = nn.TransformerEncoderLayer(
-            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
+            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2, batch_first=False
+        )
         self.epitope_transformer_encoder = nn.TransformerEncoder(
-            self.epitope_transformer_layers, num_layers=n_layers)
+            self.epitope_transformer_layers, num_layers=n_layers
+        )
         self.hla_transformer_layers = nn.TransformerEncoderLayer(
-            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
+            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2, batch_first=False
+        )
         self.hla_transformer_encoder = nn.TransformerEncoder(
-            self.hla_transformer_layers, num_layers=n_layers)
-
-        # Cross Attention layers
+            self.hla_transformer_layers, num_layers=n_layers
+        )
+
+        # Cross-attention layers (expect [S, B, D])
         self.cross_attention_epitope_layers = nn.ModuleList(
-            [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
+            [nn.MultiheadAttention(d_model, n_head, dropout=0.2, batch_first=False) for _ in range(4)]
+        )
         self.cross_attention_hla_layers = nn.ModuleList(
-            [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
+            [nn.MultiheadAttention(d_model, n_head, dropout=0.2, batch_first=False) for _ in range(4)]
+        )

         self.bn1 = nn.BatchNorm1d(cnn_num_channel)
         self.bn2 = nn.BatchNorm1d(cnn_num_channel)
+
+        fused_dim = 2 * d_model + 2 * cnn_num_channel
+        hidden_dim = 2 * (d_model + cnn_num_channel) // 4
         self.fc_task = nn.Sequential(
-            nn.Linear(2*d_model + 2*cnn_num_channel, 2 * (d_model + cnn_num_channel) // 4),
-            nn.BatchNorm1d(2 * (d_model + cnn_num_channel) // 4),
+            nn.Linear(fused_dim, hidden_dim),
+            nn.BatchNorm1d(hidden_dim),
             nn.Dropout(0.2),
             nn.SiLU(),
-            nn.Linear(2 * (d_model + cnn_num_channel) // 4, 96),
+            nn.Linear(hidden_dim, 96),
             nn.BatchNorm1d(96),
         )
         self.classifier = nn.Linear(96, 2)

     def cnn_block1(self, x):
+        # x: (B, C, L)
         return self.cnn1(self.relu(x))

     def cnn_block2(self, x):
-        x = self.padding2(x)
-        px = self.maxpooling(x)
+        # x: (B, C, L)
+        x = self.padding2(x)     # pad right by 1
+        px = self.maxpooling(x)  # downsample
         x = self.relu(px)
         x = self.cnn1(x)
         x = self.relu(x)
@@ -99,45 +165,58 @@ class TransHLA2(PreTrainedModel):
         x = px + x
         return x

-    def forward(self, epitope_in, hla_in):
-        epitope_emb = self.epitope_lora(epitope_in).last_hidden_state
-        hla_emb = self.hla_lora(hla_in).last_hidden_state
-
-        epitope_trans = self.epitope_transformer_encoder(epitope_emb.transpose(0, 1))
-        hla_trans = self.hla_transformer_encoder(hla_emb.transpose(0, 1))
-
-        # Cross Attention layers
-        for cross_attention_epitope, cross_attention_hla in zip(self.cross_attention_epitope_layers, self.cross_attention_hla_layers):
-            epitope_trans, _ = cross_attention_epitope(epitope_trans, hla_trans, hla_trans)
-            hla_trans, _ = cross_attention_hla(hla_trans, epitope_trans, epitope_trans)
-
-        # Mean Pooling
-        epitope_mean = epitope_trans.mean(dim=0)
-        hla_mean = hla_trans.mean(dim=0)
-
-        epitope_cnn_emb = self.region_cnn1(epitope_emb.transpose(1, 2))
+    def forward(self, epitope_in, hla_in, return_dict=None):
+        # epitope_in, hla_in: ESM inputs (typically input_ids / attention_mask),
+        # assumed to be standard ESM input dicts, e.g.:
+        #   epitope_in = {"input_ids": ..., "attention_mask": ...}
+        #   hla_in = {"input_ids": ..., "attention_mask": ...}
+
+        epitope_outputs = self.epitope_lora(**epitope_in)
+        hla_outputs = self.hla_lora(**hla_in)
+        # last_hidden_state: (B, L, D)
+        epitope_emb = epitope_outputs.last_hidden_state
+        hla_emb = hla_outputs.last_hidden_state
+
+        # Transformer encoder path (expects [S, B, D])
+        epitope_trans = self.epitope_transformer_encoder(epitope_emb.transpose(0, 1))  # (L, B, D)
+        hla_trans = self.hla_transformer_encoder(hla_emb.transpose(0, 1))  # (L, B, D)
+
+        # Cross attention
+        for ca_e, ca_h in zip(self.cross_attention_epitope_layers, self.cross_attention_hla_layers):
+            epitope_trans, _ = ca_e(epitope_trans, hla_trans, hla_trans)  # (L, B, D)
+            hla_trans, _ = ca_h(hla_trans, epitope_trans, epitope_trans)  # (L, B, D)
+
+        # Mean pooling over sequence length
+        epitope_mean = epitope_trans.mean(dim=0)  # (B, D)
+        hla_mean = hla_trans.mean(dim=0)  # (B, D)
+
+        # CNN branches expect (B, C, L); convert ESM embeddings to (B, D, L)
+        epitope_cnn_emb = epitope_emb.transpose(1, 2)  # (B, D, L)
+        epitope_cnn_emb = self.region_cnn1(epitope_cnn_emb)  # (B, C, L')
         epitope_cnn_emb = self.padding1(epitope_cnn_emb)
         conv = epitope_cnn_emb + self.cnn_block1(self.cnn_block1(epitope_cnn_emb))
+        # iteratively shrink the length until it is < 2
         while conv.size(-1) >= 2:
             conv = self.cnn_block2(conv)
-        epitope_cnn_out = torch.squeeze(conv, dim=-1)
+        epitope_cnn_out = torch.squeeze(conv, dim=-1)  # (B, C)
         epitope_cnn_out = self.bn1(epitope_cnn_out)

-        hla_cnn_emb = self.region_cnn2(hla_emb.transpose(1, 2))
+        hla_cnn_emb = hla_emb.transpose(1, 2)  # (B, D, L)
+        hla_cnn_emb = self.region_cnn2(hla_cnn_emb)  # (B, C, L')
         hla_cnn_emb = self.padding1(hla_cnn_emb)
         hla_conv = hla_cnn_emb + self.cnn_block1(self.cnn_block1(hla_cnn_emb))
         while hla_conv.size(-1) >= 2:
             hla_conv = self.cnn_block2(hla_conv)
-
-        hla_cnn_out = torch.squeeze(hla_conv, dim=-1)
+        hla_cnn_out = torch.squeeze(hla_conv, dim=-1)  # (B, C)
         hla_cnn_out = self.bn2(hla_cnn_out)
-        # CNN Blocks
-
-        # Concatenate and pass through MLP
-        representation = torch.cat((epitope_mean, hla_mean, epitope_cnn_out, hla_cnn_out), dim=1)
-        reduction_feature = self.fc_task(representation)
-        logits_clsf = self.classifier(reduction_feature)
-        logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
-        return logits_clsf, representation

+        # Fuse and classify
+        representation = torch.cat((epitope_mean, hla_mean, epitope_cnn_out, hla_cnn_out), dim=1)  # (B, 2D+2C)
+        features = self.fc_task(representation)  # (B, 96)
+        logits = self.classifier(features)  # (B, 2)

+        if self.config.return_prob:
+            probs = torch.softmax(logits, dim=1)
+            return probs, representation
+        else:
+            return logits, representation
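
For reviewers who want to exercise the updated API, here is a minimal usage sketch (an illustration, not part of the commit). It assumes the file is importable as modeling_transhla2, that the tokenizer shipped with facebook/esm2_t12_35M_UR50D is the intended one (its 480-dim hidden size matches the default d_model), and that forward receives ESM-style input dicts as the code comments describe; the peptide and HLA pseudo-sequence below are placeholders.

    import torch
    from transformers import AutoTokenizer

    from modeling_transhla2 import TransHLA2, TransHLA2Config

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    config = TransHLA2Config()        # d_model=480 matches the ESM-2 35M hidden size
    model = TransHLA2(config).eval()  # eval() so BatchNorm/Dropout behave deterministically

    # Placeholder inputs; real peptides and HLA pseudo-sequences come from the dataset.
    epitope_in = tokenizer(["SIINFEKL"], return_tensors="pt")
    hla_in = tokenizer(["YYAMYQENVAQTDVDTLYIIYRDYTWAELAYTWY"], return_tensors="pt")

    with torch.no_grad():
        probs, representation = model(epitope_in, hla_in)

    print(probs.shape)           # torch.Size([1, 2]); probabilities because return_prob=True
    print(representation.shape)  # torch.Size([1, 1472]) = 2*d_model + 2*cnn_num_channel

Tokenizer outputs are dict-like BatchEncoding objects, so they unpack cleanly into the **epitope_in / **hla_in calls inside forward. With freshly initialized heads the outputs are of course untrained; in practice the weights would come from the hosted checkpoint, and if the repo also registers these classes for auto loading (not shown in this diff), the model could presumably be loaded via AutoModel.from_pretrained(..., trust_remote_code=True).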