import torch
import torch.nn as nn
from peft import LoraConfig, TaskType, get_peft_model
from transformers import EsmModel, PreTrainedModel, PretrainedConfig

class TransHLA2Config(PretrainedConfig):
    """Configuration for TransHLA2. `d_model` must match the hidden size of the
    ESM-2 backbone (480 for facebook/esm2_t12_35M_UR50D)."""

    model_type = "transhla2"

    def __init__(self, d_model=480, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model


class TransHLA2(PreTrainedModel):
    config_class = TransHLA2Config

    def __init__(self, config):
        super().__init__(config)
        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"

        # LoRA adapters over the ESM-2 attention projections and dense layers.
        self.peft_config = LoraConfig(
            target_modules=["query", "out_proj", "value", "key", "dense", "regression"],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        )

        d_model = config.d_model
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        self.lora_esm = get_peft_model(self.esm, self.peft_config)

        # Projection head on top of the mean-pooled per-token embedding.
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        self.classifier = nn.Linear(32, 2)

    def forward(self, x_in):
        # x_in: token ids of shape (batch, seq_len) from the ESM-2 tokenizer.
        lora_outputs = self.lora_esm(input_ids=x_in)
        last_hidden_state = lora_outputs.last_hidden_state

        # Mean-pool over the sequence dimension (special/padding tokens included).
        out_linear = last_hidden_state.mean(dim=1)
        H = self.fc_task(out_linear)
        output = self.classifier(H)
        return output, last_hidden_state
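

# A minimal usage sketch, not part of the original module: load the matching
# ESM-2 tokenizer, encode an example peptide, and run one forward pass.
# "SIINFEKL" is purely illustrative.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = TransHLA2Config()
    model = TransHLA2(config)
    model.eval()

    # Only the LoRA adapter weights (plus the task head) are trainable.
    model.lora_esm.print_trainable_parameters()

    tokenizer = AutoTokenizer.from_pretrained(model.tokenizer_name_or_path)
    inputs = tokenizer("SIINFEKL", return_tensors="pt")

    with torch.no_grad():
        logits, hidden = model(inputs["input_ids"])
    print(logits.shape)  # torch.Size([1, 2]) -- binary classification logits
    print(hidden.shape)  # torch.Size([1, seq_len, 480]) -- per-token embeddings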