"""TransHLA2: two-class classifier over (epitope, HLA) sequence pairs.

Defines a Hugging Face config class, a LoRA-adapted ESM-2 feature extractor
(``LoraESM``), a module-level shared instance of that extractor, and the
two-input ``TransHLA2`` model, which combines transformer, cross-attention and
CNN branches over the epitope and HLA embeddings.
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from peft import LoraConfig, get_peft_model, TaskType
from transformers import EsmModel


class TransHLA2Config(PretrainedConfig):
    """Hugging Face config for ``TransHLA2``.

    Args:
        d_model: Width of the per-position embeddings consumed by the
            TransHLA2 transformer and CNN branches. It must equal the hidden
            size of the embeddings returned by the shared ``LoraESM`` backbone.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """
    model_type = "transhla2"  # identifier string for this config type

    def __init__(self, d_model: int = 480, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        # Other custom hyperparameters can be added here.


class LoraESM(nn.Module):
    """ESM-2 (``facebook/esm2_t12_35M_UR50D``) with LoRA adapters plus a 2-way head.

    The pretrained ``EsmModel`` is wrapped with ``peft.get_peft_model``, which
    attaches LoRA adapters to the modules whose names match ``target_modules``.
    A small head (``fc_task`` + ``classifier``) runs on the mean-pooled last
    hidden state.

    Args:
        d_model: Input width of ``fc_task``. It must equal the backbone's
            hidden size; 480 presumably matches esm2_t12_35M. Confirm this if
            the checkpoint name changes.
    """

    def __init__(self, d_model: int = 480):
        super().__init__()
        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
        # Stored for reference only; not used anywhere in this class.
        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
        # LoRA adapters on modules matching these names: rank 8, alpha 32,
        # dropout 0.1. inference_mode=False keeps the adapters trainable.
        self.peft_config = LoraConfig(
            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1
        )
        # Loads pretrained weights at construction time (from the Hugging Face
        # Hub or the local cache).
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        # NOTE(review): both self.esm and the PEFT wrapper built around it are
        # registered as submodules. The backbone parameters are therefore
        # likely reachable under two state_dict prefixes ("esm." and
        # "lora_esm."). Existing checkpoints depend on these keys, so do not
        # rename or remove either attribute without migrating them.
        self.lora_esm = get_peft_model(self.esm, self.peft_config)
        # 2-way head: d_model -> d_model // 4 -> 32 -> 2.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        self.classifier = nn.Linear(32, 2)

    def forward(self, x_in):
        """Run the LoRA-adapted backbone and the 2-way head.

        Args:
            x_in: Passed as-is to the PEFT-wrapped EsmModel. Presumably this is
                ESM token ids of shape (batch, seq_len); TODO confirm with the
                caller and tokenizer.

        Returns:
            A tuple ``(output, last_hidden_state)``:
              - ``output`` is the raw (no softmax) 2-way head output.
              - ``last_hidden_state`` is the backbone's final hidden states,
                presumably (batch, seq_len, d_model).
        """
        lora_outputs = self.lora_esm(x_in)
        last_hidden_state = lora_outputs.last_hidden_state
        # Unmasked mean over dim 1: every position, including any padding or
        # special tokens, contributes to the pooled vector.
        out_linear = last_hidden_state.mean(dim=1)
        H = self.fc_task(out_linear)
        output = self.classifier(H)
        return output, last_hidden_state


# Module-level backbone instance. It is built, and its pretrained weights
# loaded, at import time. Every TransHLA2 instance assigns this same object to
# self.lora_esm, so all TransHLA2 models in a process share one backbone.
# NOTE(review): importing this module therefore has a heavy side effect, and
# two TransHLA2 models cannot hold independent backbone weights. Moving
# construction into TransHLA2.__init__ would change RNG/init order. It may also
# change how the backbone is built under PreTrainedModel.from_pretrained.
# Evaluate both before refactoring.
lora_esm = LoraESM()


class TransHLA2(PreTrainedModel):
    """Two-input classifier over (epitope, HLA) sequences.

    Both inputs are embedded by the shared module-level ``lora_esm`` backbone.
    Each embedding then feeds two branches:

    * Transformer branch: a per-input TransformerEncoder, then four rounds of
      epitope<->HLA cross-attention, then mean pooling over positions.
    * CNN branch: a region convolution, a residual pair of convolutions, then
      repeated downsampling blocks until the sequence length is 1.

    The four pooled vectors are concatenated and reduced by ``fc_task`` to a
    96-dim feature. That feature is classified into 2 classes, and softmax
    probabilities are returned.
    """
    config_class = TransHLA2Config

    def __init__(self, config):
        super().__init__(config)
        # Fixed architecture hyperparameters; only d_model comes from config.
        n_layers = 4
        # NOTE: nn.MultiheadAttention needs d_model % n_head == 0 (480 / 8 = 60).
        n_head = 8
        d_model = config.d_model
        d_ff = 64
        cnn_num_channel = 256
        region_embedding_size = 3
        cnn_kernel_size = 3
        cnn_padding_size = 1
        cnn_stride = 1
        pooling_size = 2

        # NOTE(review): the attribute names below are state_dict keys. Their
        # creation order fixes parameter order and random-init order. Keep
        # both stable for checkpoint and reproducibility compatibility.

        # The shared module-level lora_esm instance, not a fresh copy.
        self.lora_esm = lora_esm
        # Region embeddings (region_cnn1: epitope, region_cnn2: HLA). Kernel 3
        # with no padding shortens the sequence by 2; padding1 adds one zero
        # on each side to restore the original length.
        self.region_cnn1 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
        self.region_cnn2 = nn.Conv1d(d_model, cnn_num_channel, region_embedding_size)
        self.padding1 = nn.ConstantPad1d((1, 1), 0)
        # Right-pads by one so the size-2 max-pool keeps the last position of
        # an odd-length sequence (pooled length = ceil(L / 2)).
        self.padding2 = nn.ConstantPad1d((0, 1), 0)
        # Despite the name, the activation used throughout is SiLU, not ReLU.
        self.relu = nn.SiLU()
        # Length-preserving convs (kernel 3, padding 1, stride 1). cnn1 is
        # reused by every epitope CNN block and cnn2 by every HLA block, so
        # weights are shared within each branch.
        self.cnn1 = nn.Conv1d(cnn_num_channel, cnn_num_channel,
                              kernel_size=cnn_kernel_size, padding=cnn_padding_size, stride=cnn_stride)
        self.cnn2 = nn.Conv1d(cnn_num_channel, cnn_num_channel,
                              kernel_size=cnn_kernel_size, padding=cnn_padding_size, stride=cnn_stride)
        self.maxpooling = nn.MaxPool1d(kernel_size=pooling_size)
        # The encoders use the default batch_first=False layout
        # (seq, batch, d_model); forward() transposes accordingly.
        # NOTE(review): nn.TransformerEncoder deep-copies the layer it is
        # given. The *_transformer_layers templates' own parameters are
        # therefore presumably unused in forward, though still in state_dict.
        self.epitope_transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.epitope_transformer_encoder = nn.TransformerEncoder(
            self.epitope_transformer_layers, num_layers=n_layers)
        self.hla_transformer_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=0.2)
        self.hla_transformer_encoder = nn.TransformerEncoder(
            self.hla_transformer_layers, num_layers=n_layers)
        # Cross Attention layers: four epitope-queries-HLA modules and four
        # HLA-queries-epitope modules, applied pairwise in forward().
        self.cross_attention_epitope_layers = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
        self.cross_attention_hla_layers = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_head, dropout=0.2) for _ in range(4)])
        # BatchNorm over the pooled cnn_num_channel-wide CNN output of each branch.
        self.bn1 = nn.BatchNorm1d(cnn_num_channel)
        self.bn2 = nn.BatchNorm1d(cnn_num_channel)
        # Head input = epitope mean + HLA mean (d_model each) + epitope CNN +
        # HLA CNN (cnn_num_channel each). Hidden width is
        # 2 * (d_model + cnn_num_channel) // 4; output is a 96-dim feature.
        self.fc_task = nn.Sequential(
            nn.Linear(2*d_model + 2*cnn_num_channel, 2 * (d_model + cnn_num_channel) // 4),
            nn.BatchNorm1d(2 * (d_model + cnn_num_channel) // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(2 * (d_model + cnn_num_channel) // 4, 96),
            nn.BatchNorm1d(96),
        )
        self.classifier = nn.Linear(96, 2)

    def cnn_block1(self, x: torch.Tensor) -> torch.Tensor:
        """Epitope pre-activation conv, cnn1(SiLU(x)); preserves sequence length."""
        return self.cnn1(self.relu(x))

    def cnn_block2(self, x: torch.Tensor) -> torch.Tensor:
        """Halve the epitope CNN sequence length with a residual conv pair.

        Right-pads by one and max-pools with size 2 (length L -> ceil(L/2)).
        It then adds cnn1(SiLU(cnn1(SiLU(pooled)))) back onto the pooled
        tensor. Both convs reuse the same ``cnn1`` weights.
        """
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn1(x)
        x = self.relu(x)
        x = self.cnn1(x)
        x = px + x  # residual connection around the two convs
        return x

    def structure_block1(self, x: torch.Tensor) -> torch.Tensor:
        """HLA pre-activation conv, cnn2(SiLU(x)); preserves sequence length."""
        return self.cnn2(self.relu(x))

    def structure_block2(self, x: torch.Tensor) -> torch.Tensor:
        """Halve the HLA CNN sequence length with a residual conv pair.

        Same structure as ``cnn_block2`` but uses ``cnn2``. Both convs reuse
        the same ``cnn2`` weights.
        """
        x = self.padding2(x)
        px = self.maxpooling(x)
        x = self.relu(px)
        x = self.cnn2(x)
        x = self.relu(x)
        x = self.cnn2(x)
        x = px + x  # residual connection around the two convs
        return x

    def forward(self, epitope_in, hla_in):
        """Classify an (epitope, HLA) pair.

        Args:
            epitope_in: Epitope input passed to the shared ``lora_esm``
                backbone. Presumably ESM token ids, (batch, seq_len); TODO
                confirm.
            hla_in: HLA input passed to the same backbone (same assumption).

        Returns:
            A tuple ``(probs, reduction_feature)``:
              - ``probs`` is the softmax over 2 classes along dim 1,
                presumably (batch, 2).
              - ``reduction_feature`` is the 96-dim ``fc_task`` output,
                presumably (batch, 96).

        NOTE(review): each CNN branch needs an embedded sequence length of at
        least 3 (region conv kernel 3, no padding). The BatchNorm layers in
        training mode also need more than one sample per batch.
        """
        # Both inputs go through the SAME shared backbone. Its 2-way head
        # output is discarded; only the hidden states are used.
        _, epitope_emb = self.lora_esm(epitope_in)
        _, hla_emb = self.lora_esm(hla_in)
        # transpose(0, 1) converts (batch, seq, d) to (seq, batch, d) for the
        # batch_first=False encoders. This assumes the backbone output is
        # batch-first.
        epitope_trans = self.epitope_transformer_encoder(epitope_emb.transpose(0, 1))
        hla_trans = self.hla_transformer_encoder(hla_emb.transpose(0, 1))
        # Cross Attention layers
        # The updates alternate: in each round the HLA update attends to the
        # epitope states already updated in that round. Outputs replace their
        # inputs, with no residual or LayerNorm, and no key_padding_mask is
        # passed.
        for cross_attention_epitope, cross_attention_hla in zip(self.cross_attention_epitope_layers, self.cross_attention_hla_layers):
            epitope_trans, _ = cross_attention_epitope(epitope_trans, hla_trans, hla_trans)
            hla_trans, _ = cross_attention_hla(hla_trans, epitope_trans, epitope_trans)
        # Mean Pooling over the sequence dimension (dim 0), unmasked.
        epitope_mean = epitope_trans.mean(dim=0)
        hla_mean = hla_trans.mean(dim=0)
        # Epitope CNN branch. transpose(1, 2) moves the embedding axis to the
        # channel position Conv1d expects.
        epitope_cnn_emb = self.region_cnn1(epitope_emb.transpose(1, 2))
        epitope_cnn_emb = self.padding1(epitope_cnn_emb)
        conv = epitope_cnn_emb + self.cnn_block1(self.cnn_block1(epitope_cnn_emb))
        # Each cnn_block2 maps length L >= 2 to ceil(L/2) < L, so the loop
        # terminates with a sequence length of 1.
        while conv.size(-1) >= 2:
            conv = self.cnn_block2(conv)
        # Drop the length-1 sequence axis, leaving (batch, cnn_num_channel).
        epitope_cnn_out = torch.squeeze(conv, dim=-1)
        epitope_cnn_out = self.bn1(epitope_cnn_out)
        # HLA CNN branch: same structure using region_cnn2 / cnn2 / bn2.
        hla_cnn_emb = self.region_cnn2(hla_emb.transpose(1, 2))
        hla_cnn_emb = self.padding1(hla_cnn_emb)
        hla_conv = hla_cnn_emb + self.structure_block1(self.structure_block1(hla_cnn_emb))
        while hla_conv.size(-1) >= 2:
            hla_conv = self.structure_block2(hla_conv)
        hla_cnn_out = torch.squeeze(hla_conv, dim=-1)
        hla_cnn_out = self.bn2(hla_cnn_out)
        # Width = 2 * d_model + 2 * cnn_num_channel, matching fc_task's input.
        representation = torch.cat((epitope_mean, hla_mean, epitope_cnn_out, hla_cnn_out), dim=1)
        reduction_feature = self.fc_task(representation)
        logits_clsf = self.classifier(reduction_feature)
        # Despite the variable name, the returned tensor is softmax
        # probabilities, not logits.
        # NOTE(review): if training feeds this output to nn.CrossEntropyLoss
        # (which applies log-softmax itself), softmax is effectively applied
        # twice. Confirm the training loss.
        logits_clsf = torch.nn.functional.softmax(logits_clsf, dim=1)
        return logits_clsf, reduction_feature


# Commented-out snippet: load a raw PyTorch state_dict into TransHLA2 and
# re-save it through the transformers API.
# config = TransHLA2Config(d_model=480)
# model = TransHLA2(config)
# model.load_state_dict(torch.load('pytorch_model.pt'))
# # 2. Save in a transformers-compatible format
# NOTE(review): transformers documents save_pretrained's first argument as a
# save *directory*. This call likely creates a directory named
# 'pytorch_model.bin'; confirm the intended target.
# model.save_pretrained('pytorch_model.bin', safe_serialization=False)
# NOTE(review): the documented registration API for custom models is
# AutoConfig.register("transhla2", TransHLA2Config) and
# AutoModel.register(TransHLA2Config, TransHLA2). Verify the mapping-level
# calls below against the installed transformers version.
# from transformers import AutoConfig, AutoModel, CONFIG_MAPPING, MODEL_MAPPING
# CONFIG_MAPPING.register("transhla2", TransHLA2Config)
# MODEL_MAPPING.register(TransHLA2Config, TransHLA2)