from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
import torch
import torch.nn as nn
from typing import Optional, Union, Tuple
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from torch.nn import MSELoss
from transformers.modeling_outputs import TokenClassifierOutput
import numpy as np
import re
from utils.lora_utils import LoRAConfig, modify_with_lora
from models.enm_adaptor_heads import (
ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)
from peft import LoraConfig, inject_adapter_in_model


class EsmForTokenRegression(EsmPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, class_config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss
        # ESM backbone without the pooling head; per-token hidden states feed the regression head
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Select the ENM adaptor head requested in class_config
        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.enm_embed_dim,
                class_config.enm_att_heads
            )
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(
                config.hidden_size,
                class_config.num_labels
            )
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.kernel_size,
                class_config.enm_embed_dim,
                class_config.num_layers
            )
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(
                config.hidden_size,
                class_config.num_labels
            )
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported.')
        self.init_weights()
    def forward(
        self,
        enm_vals=None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # The adaptor head combines the per-token ESM states with the ENM values
        logits = self.classifier(sequence_output, enm_vals, attention_mask)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return output
        # Labels are accepted for interface compatibility, but no loss is computed here
        return TokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
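
# Shape sketch for EsmForTokenRegression.forward above. These shapes are assumptions
# inferred from how the adaptor heads are called here, not verified against
# models/enm_adaptor_heads.py:
#   input_ids      : (batch, seq_len)  token ids from the ESM tokenizer
#   enm_vals       : (batch, seq_len)  per-residue ENM values consumed by the adaptor head
#   attention_mask : (batch, seq_len)
#   logits         : (batch, seq_len, class_config.num_labels)  per-token regression outputs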


def ESM_classification_model(half_precision, class_config, lora_config):
    # Load the ESM-2 encoder and its tokenizer
    if not half_precision:
        model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
        tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
    elif half_precision and torch.cuda.is_available():
        model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D", torch_dtype=torch.float16).to(torch.device('cuda'))
        tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
    else:
        raise ValueError('Half precision can be run on GPU only.')
    # Create a new classifier model with the ESM dimensions
    class_model = EsmForTokenRegression(model.config, class_config)
    # Set the encoder weights to the checkpoint weights
    class_model.esm = model
    # Delete the checkpoint model
    del model
    # Print the number of trainable parameters before LoRA injection and freezing
    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ESM_Classifier\nTrainable Parameter: " + str(params))
    # Build the LoRA configuration; note that the lora_config argument is currently
    # unused and the settings below are hard-coded
    esm_lora_peft_config = LoraConfig(
        r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
    )
    # Add the LoRA layers to the encoder
    class_model.esm = inject_adapter_in_model(esm_lora_peft_config, class_model.esm)
    # Freeze the encoder, then unfreeze the LoRA and layer-norm parameters
    for param_name, param in class_model.esm.named_parameters():
        param.requires_grad = False
    for param_name, param in class_model.esm.named_parameters():
        if re.fullmatch(".*lora.*", param_name):
            param.requires_grad = True
        if re.fullmatch(".*layer_norm.*", param_name):
            param.requires_grad = True
    # Print the number of trainable parameters after LoRA injection and freezing
    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ESM_LoRA_Classifier\nTrainable Parameter: " + str(params) + "\n")
    return class_model, tokenizer
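

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The class_config fields and
    # values below are assumptions chosen to satisfy the constructor above; adjust them to
    # the real training configuration. Note that ESM_classification_model downloads the
    # 3B-parameter ESM-2 checkpoint, so this needs substantial memory and bandwidth.
    from types import SimpleNamespace

    demo_class_config = SimpleNamespace(
        add_pearson_loss=False,
        add_sse_loss=False,
        adaptor_architecture='direct',  # the 'direct' head only needs hidden_size and num_labels
        num_labels=1,                   # one regression target per residue
    )
    # lora_config is passed as None because the function builds its own LoraConfig internally
    class_model, tokenizer = ESM_classification_model(
        half_precision=False, class_config=demo_class_config, lora_config=None
    )

    # Run a forward pass on a toy sequence with placeholder per-residue ENM values
    sequence = "MKTAYIAKQR"
    encoded = tokenizer(sequence, return_tensors="pt")
    enm_vals = torch.zeros_like(encoded["input_ids"], dtype=torch.float)  # dummy ENM input
    with torch.no_grad():
        output = class_model(
            enm_vals=enm_vals,
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
    print(output.logits.shape)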