SkywalkerLu committed · verified
Commit c14acfa · 1 Parent(s): 14ae906

Update modeling_transhla2.py

Files changed (1)
  1. modeling_transhla2.py +143 -143
modeling_transhla2.py CHANGED
@@ -1,143 +1,143 @@
- import torch
- import torch.nn as nn
- from transformers import PreTrainedModel, PretrainedConfig
- 
- from peft import LoraConfig, get_peft_model, TaskType
- from transformers import EsmModel
- 
- class TransHLA2Config(PretrainedConfig):
-     model_type = "transhla2"
-     def __init__(self, d_model=480, **kwargs):
-         super().__init__(**kwargs)
-         self.d_model = d_model
-         # Additional custom parameters can be added here
- class TransHLA2Config(nn.Module):
-     config_class = TransHLA2Config
-     def __init__(self):
-         super(TransHLA2Config, self).__init__()
- 
-         n_layers = 4
-         n_head = 8
-         d_model = 480
-         d_ff = 64
-         cnn_num_channel = 256
-         region_embedding_size = 3
-         cnn_kernel_size = 3
-         cnn_padding_size = 1
-         cnn_stride = 1
-         pooling_size = 2
-         self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
-         self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
-         self.peft_config = LoraConfig(
-             target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'], task_type=TaskType.FEATURE_EXTRACTION, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
-         self.esm = EsmModel.from_pretrained(self.model_name_or_path)
-         self.epitope_lora = get_peft_model(self.esm, self.peft_config)
-         self.hla_lora = get_peft_model(self.esm, self.peft_config)
- 
-         self.region_cnn1 = nn.Conv1d(
-             d_model, cnn_num_channel, region_embedding_size)
-         self.region_cnn2 = nn.Conv1d(
-             d_model, cnn_num_channel, region_embedding_size)
-         self.padding1 = nn.ConstantPad1d((1, 1), 0)
-         self.padding2 = nn.ConstantPad1d((0, 1), 0)
-         self.relu = nn.SiLU()
-         self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
-                               padding=cnn_padding_size, stride=cnn_stride)
-         self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
-                               padding=cnn_padding_size, stride=cnn_stride)
-         self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
-         self.epitope_transformer_layers = nn.TransformerEncoderLayer(
-             d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
-         self.epitope_transformer_encoder = nn.TransformerEncoder(
-             self.epitope_transformer_layers, num_layers=n_layers)
-         self.hla_transformer_layers = nn.TransformerEncoderLayer(
-             d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
-         self.hla_transformer_encoder = nn.TransformerEncoder(
-             self.hla_transformer_layers, num_layers=n_layers)
- 
-         # Cross Attention layers
-         self.cross_attention_epitope_layers = nn.ModuleList(
-             [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
-         self.cross_attention_hla_layers = nn.ModuleList(
-             [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
- 
-         self.bn1 = nn.BatchNorm1d(cnn_num_channel)
-         self.bn2 = nn.BatchNorm1d(cnn_num_channel)
-         self.fc_task = nn.Sequential(
-             nn.Linear(2*d_model + 2*cnn_num_channel, 2 * (d_model + cnn_num_channel) // 4),
-             nn.BatchNorm1d(2 * (d_model + cnn_num_channel) // 4),
-             nn.Dropout(0.2),
-             nn.SiLU(),
-             nn.Linear(2 * (d_model + cnn_num_channel) // 4, 96),
-             nn.BatchNorm1d(96),
-         )
-         self.classifier = nn.Linear(96, 2)
- 
-     def cnn_block1(self, x):
-         return self.cnn1(self.relu(x))
- 
-     def cnn_block2(self, x):
-         x = self.padding2(x)
-         px = self.maxpooling(x)
-         x = self.relu(px)
-         x = self.cnn1(x)
-         x = self.relu(x)
-         x = self.cnn1(x)
-         x = px + x
-         return x
- 
-     def structure_block1(self, x):
-         return self.cnn2(self.relu(x))
- 
-     def structure_block2(self, x):
-         x = self.padding2(x)
-         px = self.maxpooling(x)
-         x = self.relu(px)
-         x = self.cnn2(x)
-         x = self.relu(x)
-         x = self.cnn2(x)
-         x = px + x
-         return x
- 
-     def forward(self, epitope_in, hla_in):
-         epitope_emb = self.epitope_lora(epitope_in).last_hidden_state
-         hla_emb = self.hla_lora(hla_in).last_hidden_state
- 
-         epitope_trans = self.epitope_transformer_encoder(epitope_emb.transpose(0, 1))
-         hla_trans = self.hla_transformer_encoder(hla_emb.transpose(0, 1))
- 
-         # Cross Attention layers
-         for cross_attention_epitope, cross_attention_hla in zip(self.cross_attention_epitope_layers, self.cross_attention_hla_layers):
-             epitope_trans, _ = cross_attention_epitope(epitope_trans, hla_trans, hla_trans)
-             hla_trans, _ = cross_attention_hla(hla_trans, epitope_trans, epitope_trans)
- 
-         # Mean Pooling
-         epitope_mean = epitope_trans.mean(dim=0)
-         hla_mean = hla_trans.mean(dim=0)
- 
-         epitope_cnn_emb = self.region_cnn1(epitope_emb.transpose(1, 2))
-         epitope_cnn_emb = self.padding1(epitope_cnn_emb)
-         conv = epitope_cnn_emb + self.cnn_block1(self.cnn_block1(epitope_cnn_emb))
-         while conv.size(-1) >= 2:
-             conv = self.cnn_block2(conv)
-         epitope_cnn_out = torch.squeeze(conv, dim=-1)
-         epitope_cnn_out = self.bn1(epitope_cnn_out)
- 
-         hla_cnn_emb = self.region_cnn2(hla_emb.transpose(1, 2))
-         hla_cnn_emb = self.padding1(hla_cnn_emb)
-         hla_conv = hla_cnn_emb + self.structure_block1(self.structure_block1(hla_cnn_emb))
-         while hla_conv.size(-1) >= 2:
-             hla_conv = self.structure_block2(hla_conv)
- 
-         hla_cnn_out = torch.squeeze(hla_conv, dim=-1)
-         hla_cnn_out = self.bn2(hla_cnn_out)
-         # CNN Blocks
- 
-         # Concatenate and pass through MLP
-         representation = torch.cat((epitope_mean, hla_mean, epitope_cnn_out, hla_cnn_out), dim=1)
-         reduction_feature = self.fc_task(representation)
-         logits_clsf = self.classifier(reduction_feature)
-         logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
-         return logits_clsf, representation
- 
- 
 
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ 
+ from peft import LoraConfig, get_peft_model, TaskType
+ from transformers import EsmModel
+ 
+ class TransHLA2Config(PretrainedConfig):
+     model_type = "transhla2"
+     def __init__(self, d_model=480, **kwargs):
+         super().__init__(**kwargs)
+         self.d_model = d_model
+         # Additional custom parameters can be added here
+ class TransHLA2(PreTrainedModel):
+     config_class = TransHLA2Config
+     def __init__(self, config):
+         super().__init__(config)
+ 
+         n_layers = 4
+         n_head = 8
+         d_model = 480
+         d_ff = 64
+         cnn_num_channel = 256
+         region_embedding_size = 3
+         cnn_kernel_size = 3
+         cnn_padding_size = 1
+         cnn_stride = 1
+         pooling_size = 2
+         self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
+         self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
+         self.peft_config = LoraConfig(
+             target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'], task_type=TaskType.FEATURE_EXTRACTION, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
+         self.esm = EsmModel.from_pretrained(self.model_name_or_path)
+         self.epitope_lora = get_peft_model(self.esm, self.peft_config)
+         self.hla_lora = get_peft_model(self.esm, self.peft_config)
+ 
+         self.region_cnn1 = nn.Conv1d(
+             d_model, cnn_num_channel, region_embedding_size)
+         self.region_cnn2 = nn.Conv1d(
+             d_model, cnn_num_channel, region_embedding_size)
+         self.padding1 = nn.ConstantPad1d((1, 1), 0)
+         self.padding2 = nn.ConstantPad1d((0, 1), 0)
+         self.relu = nn.SiLU()
+         self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
+                               padding=cnn_padding_size, stride=cnn_stride)
+         self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
+                               padding=cnn_padding_size, stride=cnn_stride)
+         self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
+         self.epitope_transformer_layers = nn.TransformerEncoderLayer(
+             d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
+         self.epitope_transformer_encoder = nn.TransformerEncoder(
+             self.epitope_transformer_layers, num_layers=n_layers)
+         self.hla_transformer_layers = nn.TransformerEncoderLayer(
+             d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
+         self.hla_transformer_encoder = nn.TransformerEncoder(
+             self.hla_transformer_layers, num_layers=n_layers)
+ 
+         # Cross Attention layers
+         self.cross_attention_epitope_layers = nn.ModuleList(
+             [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
+         self.cross_attention_hla_layers = nn.ModuleList(
+             [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
+ 
+         self.bn1 = nn.BatchNorm1d(cnn_num_channel)
+         self.bn2 = nn.BatchNorm1d(cnn_num_channel)
+         self.fc_task = nn.Sequential(
+             nn.Linear(2*d_model + 2*cnn_num_channel, 2 * (d_model + cnn_num_channel) // 4),
+             nn.BatchNorm1d(2 * (d_model + cnn_num_channel) // 4),
+             nn.Dropout(0.2),
+             nn.SiLU(),
+             nn.Linear(2 * (d_model + cnn_num_channel) // 4, 96),
+             nn.BatchNorm1d(96),
+         )
+         self.classifier = nn.Linear(96, 2)
+ 
+     def cnn_block1(self, x):
+         return self.cnn1(self.relu(x))
+ 
+     def cnn_block2(self, x):
+         x = self.padding2(x)
+         px = self.maxpooling(x)
+         x = self.relu(px)
+         x = self.cnn1(x)
+         x = self.relu(x)
+         x = self.cnn1(x)
+         x = px + x
+         return x
+ 
+     def structure_block1(self, x):
+         return self.cnn2(self.relu(x))
+ 
+     def structure_block2(self, x):
+         x = self.padding2(x)
+         px = self.maxpooling(x)
+         x = self.relu(px)
+         x = self.cnn2(x)
+         x = self.relu(x)
+         x = self.cnn2(x)
+         x = px + x
+         return x
+ 
+     def forward(self, epitope_in, hla_in):
+         epitope_emb = self.epitope_lora(epitope_in).last_hidden_state
+         hla_emb = self.hla_lora(hla_in).last_hidden_state
+ 
+         epitope_trans = self.epitope_transformer_encoder(epitope_emb.transpose(0, 1))
+         hla_trans = self.hla_transformer_encoder(hla_emb.transpose(0, 1))
+ 
+         # Cross Attention layers
+         for cross_attention_epitope, cross_attention_hla in zip(self.cross_attention_epitope_layers, self.cross_attention_hla_layers):
+             epitope_trans, _ = cross_attention_epitope(epitope_trans, hla_trans, hla_trans)
+             hla_trans, _ = cross_attention_hla(hla_trans, epitope_trans, epitope_trans)
+ 
+         # Mean Pooling
+         epitope_mean = epitope_trans.mean(dim=0)
+         hla_mean = hla_trans.mean(dim=0)
+ 
+         epitope_cnn_emb = self.region_cnn1(epitope_emb.transpose(1, 2))
+         epitope_cnn_emb = self.padding1(epitope_cnn_emb)
+         conv = epitope_cnn_emb + self.cnn_block1(self.cnn_block1(epitope_cnn_emb))
+         while conv.size(-1) >= 2:
+             conv = self.cnn_block2(conv)
+         epitope_cnn_out = torch.squeeze(conv, dim=-1)
+         epitope_cnn_out = self.bn1(epitope_cnn_out)
+ 
+         hla_cnn_emb = self.region_cnn2(hla_emb.transpose(1, 2))
+         hla_cnn_emb = self.padding1(hla_cnn_emb)
+         hla_conv = hla_cnn_emb + self.structure_block1(self.structure_block1(hla_cnn_emb))
+         while hla_conv.size(-1) >= 2:
+             hla_conv = self.structure_block2(hla_conv)
+ 
+         hla_cnn_out = torch.squeeze(hla_conv, dim=-1)
+         hla_cnn_out = self.bn2(hla_cnn_out)
+         # CNN Blocks
+ 
+         # Concatenate and pass through MLP
+         representation = torch.cat((epitope_mean, hla_mean, epitope_cnn_out, hla_cnn_out), dim=1)
+         reduction_feature = self.fc_task(representation)
+         logits_clsf = self.classifier(reduction_feature)
+         logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
+         return logits_clsf, representation
+ 
+ 
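
For orientation, here is a minimal usage sketch of the classes defined above. It is not part of the commit: it assumes the constructor signature __init__(self, config) (the usual PreTrainedModel convention), assumes modeling_transhla2.py is importable from the working directory, uses the ESM2 tokenizer hard-coded in the file, and feeds placeholder peptide and HLA strings purely to check tensor shapes. Building the model downloads the facebook/esm2_t12_35M_UR50D backbone via EsmModel.from_pretrained.

    import torch
    from transformers import AutoTokenizer
    from modeling_transhla2 import TransHLA2, TransHLA2Config  # assumes the file is on the import path

    # Build the model from its config; __init__ pulls the ESM2 backbone from the Hub.
    config = TransHLA2Config(d_model=480)
    model = TransHLA2(config)
    model.eval()  # eval mode so the BatchNorm layers use running statistics

    # Tokenize placeholder sequences (illustrative only, not real binding data).
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
    epitope_ids = tokenizer("SIINFEKL", return_tensors="pt")["input_ids"]
    hla_ids = tokenizer("A" * 34, return_tensors="pt")["input_ids"]  # dummy stand-in for an HLA pseudo-sequence

    with torch.no_grad():
        probs, representation = model(epitope_ids, hla_ids)

    print(probs.shape)           # (1, 2): softmax probabilities from the classifier head
    print(representation.shape)  # (1, 1472): 2*480 transformer means + 2*256 CNN features

The forward method returns both the class probabilities and the concatenated representation, so downstream code can use either the prediction or the pooled features.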