"""ESM-2 token-level regression with ENM adaptor heads and LoRA fine-tuning."""

import re
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import MSELoss

from peft import LoraConfig, inject_adapter_in_model
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.esm.modeling_esm import EsmModel, EsmPreTrainedModel

from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier,
    ENMAdaptedConvClassifier,
    ENMAdaptedDirectClassifier,
    ENMNoAdaptorClassifier,
)
|
|
class EsmForTokenRegression(EsmPreTrainedModel):
    """ESM-2 encoder with a per-token regression head that fuses the residue
    embeddings with per-residue ENM values via a configurable adaptor head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, class_config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Auxiliary-loss flags; stored on the model for the training code to
        # read, they are not consumed inside this module.
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.enm_embed_dim,
                class_config.enm_att_heads,
            )
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(
                config.hidden_size,
                class_config.num_labels,
            )
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.kernel_size,
                class_config.enm_embed_dim,
                class_config.num_layers,
            )
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(
                config.hidden_size,
                class_config.num_labels,
            )
        else:
            raise ValueError(
                f"Unknown adaptor architecture {class_config.adaptor_architecture!r}: "
                "only 'attention', 'direct', 'conv' and 'no-adaptor' are supported."
            )

        self.init_weights()

    def forward(
        self,
        enm_vals: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])

        # The adaptor head combines the residue embeddings with the ENM values.
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        loss = None
        if labels is not None:
            # Basic MSE over the valid (non-padding) tokens; this assumes one
            # regression target per token. The Pearson/SSE terms controlled by
            # the flags in __init__ are expected to be added by the training
            # code, not here.
            loss_fct = MSELoss()
            preds = logits.view(labels.shape)
            if attention_mask is not None:
                mask = attention_mask.bool()
                loss = loss_fct(preds[mask], labels[mask].float())
            else:
                loss = loss_fct(preds, labels.float())

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
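
# A minimal sketch of the `class_config` object EsmForTokenRegression expects,
# assuming a plain namespace. The attribute names come from __init__ above;
# the values are illustrative assumptions, not defaults of this repository:
#
#     from types import SimpleNamespace
#     example_class_config = SimpleNamespace(
#         adaptor_architecture="conv",  # 'attention' | 'direct' | 'conv' | 'no-adaptor'
#         num_labels=1,                 # one regression target per residue
#         enm_embed_dim=64,             # 'attention' and 'conv' heads
#         enm_att_heads=4,              # 'attention' head only
#         kernel_size=7,                # 'conv' head only
#         num_layers=2,                 # 'conv' head only
#         add_pearson_loss=False,
#         add_sse_loss=False,
#     )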
|
|
|
|
|
def ESM_classification_model(half_precision, class_config, lora_config):
    """Build the ESM-2 3B token-regression model and its tokenizer, inject
    LoRA adapters, and freeze everything except LoRA and layer-norm weights.

    Note: ``lora_config`` is currently unused; the PEFT configuration below
    is hard-coded.
    """
    if not half_precision:
        model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
        tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
    elif torch.cuda.is_available():
        model = EsmModel.from_pretrained(
            "facebook/esm2_t36_3B_UR50D", torch_dtype=torch.float16
        ).to(torch.device('cuda'))
        tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Reuse the pretrained encoder instead of the freshly initialised copy
    # that EsmForTokenRegression.__init__ creates.
    class_model = EsmForTokenRegression(model.config, class_config)
    class_model.esm = model
    del model

    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    print("ESM_Classifier\nTrainable parameters: " + str(params))

    # Inject LoRA adapters into the attention projections and dense layers.
    esm_lora_peft_config = LoraConfig(
        r=4,
        lora_alpha=1,
        bias="all",
        target_modules=["query", "key", "value", "dense"],
    )
    class_model.esm = inject_adapter_in_model(esm_lora_peft_config, class_model.esm)

    # Freeze the encoder, keeping only the LoRA adapters and the layer norms
    # trainable (the regression head outside `esm` stays trainable as well).
    for param_name, param in class_model.esm.named_parameters():
        param.requires_grad = bool(
            re.fullmatch(".*lora.*", param_name)
            or re.fullmatch(".*layer_norm.*", param_name)
        )

    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    print("ESM_LoRA_Classifier\nTrainable parameters: " + str(params) + "\n")

    return class_model, tokenizer
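

if __name__ == "__main__":
    # Hedged smoke test: exercises the factory with an illustrative config.
    # The attribute values are assumptions, and `enm_vals` is random; its
    # expected shape depends on the adaptor head chosen above.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        adaptor_architecture="direct",
        num_labels=1,
        enm_embed_dim=64,
        enm_att_heads=4,
        kernel_size=7,
        num_layers=2,
        add_pearson_loss=False,
        add_sse_loss=False,
    )
    model, tokenizer = ESM_classification_model(
        half_precision=False, class_config=cfg, lora_config=None
    )
    batch = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    enm_vals = torch.rand(batch["input_ids"].shape)
    with torch.no_grad():
        out = model(enm_vals=enm_vals, **batch)
    print(out.logits.shape)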