# TriStageHLA-PRE / modeling_transhla2.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from peft import LoraConfig, get_peft_model, TaskType
from transformers import EsmModel
class TransHLA2Config(PretrainedConfig):
    """Configuration object for the TransHLA2 classifier.

    Attributes:
        d_model: hidden size of the ESM-2 backbone embedding
            (480 matches ``facebook/esm2_t12_35M_UR50D``).
    """

    model_type = "transhla2"

    def __init__(self, d_model=480, **kwargs):
        super().__init__(**kwargs)
        # Hidden width consumed by the classification head.
        self.d_model = d_model
        # Further custom hyperparameters can be added here.
class TransHLA2(PreTrainedModel):
    """Binary sequence classifier on a LoRA-adapted ESM-2 backbone.

    Pipeline: ESM-2 encoder (with LoRA adapters on its attention/projection
    modules) -> mean-pooled sequence embedding -> small MLP head -> 2-way
    logits.
    """

    config_class = TransHLA2Config

    def __init__(self, config):
        super().__init__(config)
        # Backbone checkpoint. Overridable through the config for reuse with
        # other ESM-2 sizes; defaults preserve the original hard-coded path
        # (backward compatible).
        self.model_name_or_path = getattr(
            config, "model_name_or_path", "facebook/esm2_t12_35M_UR50D"
        )
        self.tokenizer_name_or_path = getattr(
            config, "tokenizer_name_or_path", self.model_name_or_path
        )
        # LoRA adapters injected into the listed ESM submodules.
        self.peft_config = LoraConfig(
            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        # Must equal the backbone hidden size (480 for esm2_t12_35M).
        d_model = config.d_model
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        self.lora_esm = get_peft_model(self.esm, self.peft_config)
        # Classification head: d_model -> d_model//4 -> 32 features.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        self.classifier = nn.Linear(32, 2)

    def forward(self, x_in, attention_mask=None):
        """Run the classifier on a batch of token ids.

        Args:
            x_in: token-id tensor, presumably (batch, seq_len) — confirm
                against the tokenizer/caller.
            attention_mask: optional (batch, seq_len) 0/1 mask. When given,
                it is passed to the encoder and padding positions are
                excluded from the mean pooling. When omitted, behavior is
                identical to the original implementation (plain mean over
                every position, padding included).

        Returns:
            Tuple ``(logits, last_hidden_state)`` — 2-way logits of shape
            (batch, 2) and the encoder's final hidden states.
        """
        lora_outputs = self.lora_esm(x_in, attention_mask=attention_mask)
        last_hidden_state = lora_outputs.last_hidden_state
        if attention_mask is None:
            # Original behavior: unmasked mean over the sequence dimension.
            pooled = last_hidden_state.mean(dim=1)
        else:
            # Masked mean so padding tokens do not dilute the embedding.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        H = self.fc_task(pooled)
        output = self.classifier(H)
        return output, last_hidden_state