import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, EsmModel
from peft import LoraConfig, get_peft_model, TaskType


class TransHLA2Config(PretrainedConfig):
    model_type = "transhla2"

    def __init__(self, d_model=480, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model  # Other custom parameters can be added here


class TransHLA2(PreTrainedModel):
    config_class = TransHLA2Config

    def __init__(self, config):
        super().__init__(config)
        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
        # LoRA adapters are injected into the ESM-2 attention and dense projections.
        self.peft_config = LoraConfig(
            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1
        )
        d_model = config.d_model
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        self.lora_esm = get_peft_model(self.esm, self.peft_config)
        # Task head: project the pooled ESM-2 embedding down to a 32-dim representation.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        self.classifier = nn.Linear(32, 2)  # binary classification logits

    def forward(self, x_in):
        # x_in: token IDs of shape (batch, seq_len) from the ESM-2 tokenizer.
        lora_outputs = self.lora_esm(x_in)
        last_hidden_state = lora_outputs.last_hidden_state
        # Mean-pool over the sequence dimension to obtain one vector per sequence.
        out_linear = last_hidden_state.mean(dim=1)
        H = self.fc_task(out_linear)
        output = self.classifier(H)
        return output, last_hidden_state
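

# A minimal usage sketch (not part of the original module): tokenize a sequence
# with the matching ESM-2 tokenizer and run a forward pass. The peptide string
# below is a hypothetical placeholder; any amino-acid sequence would do.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = TransHLA2Config(d_model=480)  # 480 is the esm2_t12_35M hidden size
    model = TransHLA2(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model.tokenizer_name_or_path)
    inputs = tokenizer("SIINFEKL", return_tensors="pt")  # example peptide
    with torch.no_grad():
        logits, hidden = model(inputs["input_ids"])
    print(logits.shape)  # (1, 2) class logits
    print(hidden.shape)  # (1, seq_len, 480) per-token embeddings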