import torch
import torch.nn as nn
from transformers import EsmModel, PreTrainedModel, PretrainedConfig
from peft import LoraConfig, get_peft_model, TaskType

class TransHLA2Config(PretrainedConfig):
    model_type = "transhla2"

    def __init__(self, d_model=480, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        # Additional custom parameters can be added here.

class LoraESM(nn.Module):
    """ESM-2 backbone wrapped with LoRA adapters plus a small classification head."""

    def __init__(self, d_model=480):
        super().__init__()
        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.peft_config = LoraConfig(
            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        self.lora_esm = get_peft_model(self.esm, self.peft_config)
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        self.classifier = nn.Linear(32, 2)

    def forward(self, x_in):
        # x_in: (batch, seq_len) token ids; last_hidden_state: (batch, seq_len, d_model)
        lora_outputs = self.lora_esm(x_in)
        last_hidden_state = lora_outputs.last_hidden_state
        out_linear = last_hidden_state.mean(dim=1)  # mean pooling over the sequence
        H = self.fc_task(out_linear)
        output = self.classifier(H)
        return output, last_hidden_state

lora_esm = LoraESM()
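
# A minimal smoke-test sketch for the module-level LoRA-ESM backbone above (not part of
# the original code): the (batch, seq_len) token-id input shape and the 480-dim hidden
# size are assumptions based on the facebook/esm2_t12_35M_UR50D checkpoint.
# Kept commented out so importing this file has no side effects.
# if __name__ == "__main__":
#     dummy_ids = torch.randint(0, 33, (2, 16))   # 2 dummy sequences, 16 tokens each
#     lora_esm.eval()                              # use running BatchNorm stats
#     with torch.no_grad():
#         logits, hidden = lora_esm(dummy_ids)
#     print(logits.shape, hidden.shape)            # (2, 2) and (2, 16, 480)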

class TransHLA2(PreTrainedModel):
    config_class = TransHLA2Config

    def __init__(self, config):
        super().__init__(config)
        n_layers = 4
        n_head = 8
        d_model = config.d_model
        d_ff = 64
        cnn_num_channel = 256
        region_embedding_size = 3
        cnn_kernel_size = 3
        cnn_padding_size = 1
        cnn_stride = 1
        pooling_size = 2
        self.lora_esm = lora_esm
        self.region_cnn1 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
        self.region_cnn2 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        self.relu = nn.SiLU()
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel, kernel_size=cnn_kernel_size,
                              padding=cnn_padding_size, stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
        self.epitope_transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.epitope_transformer_encoder = nn.TransformerEncoder(
            self.epitope_transformer_layers, num_layers=n_layers)
        self.hla_transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.hla_transformer_encoder = nn.TransformerEncoder(
            self.hla_transformer_layers, num_layers=n_layers)
        # Cross-attention layers between the epitope and HLA streams
        self.cross_attention_epitope_layers = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
        self.cross_attention_hla_layers = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
        self.bn1 = nn.BatchNorm1d(cnn_num_channel)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        self.fc_task = nn.Sequential(
            nn.Linear(2 * d_model + 2 * cnn_num_channel, 2 * (d_model + cnn_num_channel) // 4),
            nn.BatchNorm1d(2 * (d_model + cnn_num_channel) // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(2 * (d_model + cnn_num_channel) // 4, 96),
            nn.BatchNorm1d(96),
        )
        self.classifier = nn.Linear(96, 2)

    def cnn_block1(self, x):
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x):
        # DPCNN-style block: pad, halve the length with max pooling,
        # apply two convolutions, and add a residual connection.
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x
        return x

    def structure_block1(self, x):
        return self.cnn2(self.relu(x))

    def structure_block2(self, x):
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x
        return x

    def forward(self, epitope_in, hla_in):
        # Contextual embeddings from the shared LoRA-ESM backbone: (batch, seq_len, d_model)
        _, epitope_emb = self.lora_esm(epitope_in)
        _, hla_emb = self.lora_esm(hla_in)
        # Transformer encoders operate sequence-first: (seq_len, batch, d_model)
        epitope_trans = self.epitope_transformer_encoder(epitope_emb.transpose(0, 1))
        hla_trans = self.hla_transformer_encoder(hla_emb.transpose(0, 1))
        # Cross-attention between the epitope and HLA streams
        for cross_attention_epitope, cross_attention_hla in zip(
                self.cross_attention_epitope_layers, self.cross_attention_hla_layers):
            epitope_trans, _ = cross_attention_epitope(epitope_trans, hla_trans, hla_trans)
            hla_trans, _ = cross_attention_hla(hla_trans, epitope_trans, epitope_trans)
        # Mean pooling over the sequence dimension
        epitope_mean = epitope_trans.mean(dim=0)
        hla_mean = hla_trans.mean(dim=0)
        # Epitope CNN branch: (batch, d_model, seq_len) -> (batch, cnn_num_channel)
        epitope_cnn_emb = self.region_cnn1(epitope_emb.transpose(1, 2))
        epitope_cnn_emb = self.padding1(epitope_cnn_emb)
        conv = epitope_cnn_emb + self.cnn_block1(self.cnn_block1(epitope_cnn_emb))
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        epitope_cnn_out = torch.squeeze(conv, dim=-1)
        epitope_cnn_out = self.bn1(epitope_cnn_out)
        # HLA CNN branch
        hla_cnn_emb = self.region_cnn2(hla_emb.transpose(1, 2))
        hla_cnn_emb = self.padding1(hla_cnn_emb)
        hla_conv = hla_cnn_emb + self.structure_block1(self.structure_block1(hla_cnn_emb))
        while hla_conv.size(-1) >= 2:
            hla_conv = self.structure_block2(hla_conv)
        hla_cnn_out = torch.squeeze(hla_conv, dim=-1)
        hla_cnn_out = self.bn2(hla_cnn_out)
        # Concatenate transformer and CNN features, then classify
        representation = torch.cat((epitope_mean, hla_mean, epitope_cnn_out, hla_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        logits_clsf = self.classifier(reduction_feature)
        logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
        return logits_clsf, reduction_feature

# config = TransHLA2Config(d_model=480)
# model = TransHLA2(config)
# model.load_state_dict(torch.load('pytorch_model.pt'))
# # 2. Save in a transformers-compatible format
# model.save_pretrained('pytorch_model.bin', safe_serialization=False)
# from transformers import AutoConfig, AutoModel, CONFIG_MAPPING, MODEL_MAPPING
# CONFIG_MAPPING.register("transhla2", TransHLA2Config)
# MODEL_MAPPING.register(TransHLA2Config, TransHLA2)
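
# A hedged end-to-end inference sketch (assumed usage, not taken from the original
# code): tokenize an epitope and an HLA sequence with the ESM-2 tokenizer and run a
# paired forward pass. The example sequences and max lengths are placeholders.
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
# epitope_ids = tokenizer(["PKYVKQNTLKLAT"], return_tensors="pt",
#                         padding="max_length", max_length=21)["input_ids"]
# hla_ids = tokenizer(["GDTRPRFLWQLKFECHFFNGTERVR"], return_tensors="pt",
#                     padding="max_length", max_length=34)["input_ids"]
# config = TransHLA2Config(d_model=480)
# model = TransHLA2(config).eval()
# with torch.no_grad():
#     probs, features = model(epitope_ids, hla_ids)
# print(probs)   # softmax scores over the two classes, shape (1, 2)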