import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, EsmModel
from peft import LoraConfig, get_peft_model, TaskType


class TransHLA2Config(PretrainedConfig):
    model_type = "transhla2"

    def __init__(self, d_model=480, **kwargs):
        # 480 is the hidden size of esm2_t12_35M_UR50D.
        super().__init__(**kwargs)
        self.d_model = d_model
        # Other custom parameters can be added here.


class TransHLA2(PreTrainedModel):
    config_class = TransHLA2Config

    def __init__(self, config):
        super().__init__(config)
        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
        # Inject LoRA adapters into the ESM-2 attention and dense projections.
        self.peft_config = LoraConfig(
            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        d_model = config.d_model
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        self.lora_esm = get_peft_model(self.esm, self.peft_config)
        # Task head: project the mean-pooled embedding down to 32 dimensions.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        # Two-logit classification output.
        self.classifier = nn.Linear(32, 2)

    def forward(self, x_in):
        lora_outputs = self.lora_esm(x_in)
        last_hidden_state = lora_outputs.last_hidden_state
        # Mean-pool the token embeddings over the sequence dimension.
        out_linear = last_hidden_state.mean(dim=1)
        H = self.fc_task(out_linear)
        output = self.classifier(H)
        return output, last_hidden_state
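

# A minimal usage sketch, not part of the original file. It assumes the
# matching ESM-2 tokenizer and an arbitrary example peptide; the model
# returns two-class logits plus the token-level embeddings.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = TransHLA2Config(d_model=480)
    model = TransHLA2(config)
    model.eval()  # use BatchNorm running stats for single-sample inference

    tokenizer = AutoTokenizer.from_pretrained(model.tokenizer_name_or_path)
    batch = tokenizer(["MKTIIALSYIFCLVFA"], return_tensors="pt")  # example peptide

    with torch.no_grad():
        logits, hidden = model(batch["input_ids"])

    print(logits.shape)  # torch.Size([1, 2]) -- classification logits
    print(hidden.shape)  # torch.Size([1, seq_len, 480]) -- token embeddings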