Upload folder using huggingface_hub

- ESM_per_token.py +150 -0
- T5_encoder_per_token.py +172 -0
- __pycache__/T5_encoder_per_token.cpython-313.pyc +0 -0
- __pycache__/T5_encoder_per_token.cpython-39.pyc +0 -0
- __pycache__/enm_adaptor_heads.cpython-313.pyc +0 -0
- __pycache__/enm_adaptor_heads.cpython-39.pyc +0 -0
- enm_adaptor_heads.py +85 -0
- weights/.gitkeep +0 -0
- weights/flexpert_3d_weights.bin +3 -0
- weights/flexpert_seq_weights.bin +3 -0
ESM_per_token.py
ADDED
@@ -0,0 +1,150 @@
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
import torch
import torch.nn as nn
from typing import Optional, Union, Tuple
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from torch.nn import MSELoss
from transformers.modeling_outputs import TokenClassifierOutput
import numpy as np
import re
from utils.lora_utils import LoRAConfig, modify_with_lora
from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
    ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)
from peft import LoraConfig, inject_adapter_in_model


class EsmForTokenRegression(EsmPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, class_config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.enm_embed_dim,
                class_config.enm_att_heads
            )
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(
                config.hidden_size,
                class_config.num_labels
            )
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.kernel_size,
                class_config.enm_embed_dim,
                class_config.num_layers
            )
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(
                config.hidden_size,
                class_config.num_labels
            )
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported.')

        self.init_weights()

    def forward(
        self,
        enm_vals=None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return TokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def ESM_classification_model(half_precision, class_config, lora_config):
    # Load ESM-2 and its tokenizer
    if not half_precision:
        model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
        tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
    elif half_precision and torch.cuda.is_available():
        model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D", torch_dtype=torch.float16).to(torch.device('cuda'))
        tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Create a new classifier model with ESM dimensions
    class_model = EsmForTokenRegression(model.config, class_config)

    # Set encoder weights to checkpoint weights
    class_model.esm = model

    # Delete the checkpoint reference
    del model

    # Print number of trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ESM_Classifier\nTrainable Parameter: " + str(params))

    # LoRA configuration for the encoder (note: the lora_config argument is currently
    # unused; the PEFT configuration below is hard-coded)
    esm_lora_peft_config = LoraConfig(
        r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
    )

    # Add LoRA layers
    class_model.esm = inject_adapter_in_model(esm_lora_peft_config, class_model.esm)

    # Freeze the encoder, then unfreeze LoRA and layer-norm parameters
    for (param_name, param) in class_model.esm.named_parameters():
        param.requires_grad = False

    for (param_name, param) in class_model.esm.named_parameters():
        if re.fullmatch(".*lora.*", param_name):  # ".*layer_norm.*|.*lora_[ab].*"
            param.requires_grad = True
        if re.fullmatch(".*layer_norm.*", param_name):
            param.requires_grad = True

    # Print trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ESM_LoRA_Classifier\nTrainable Parameter: " + str(params) + "\n")

    return class_model, tokenizer
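
A minimal usage sketch (not part of the uploaded files): it builds the ESM regressor with the 'conv' adaptor head and runs one per-residue forward pass. The class_config fields below are assumptions read off the constructor above rather than the project's real configuration, and the all-zero enm_vals tensor is only a placeholder for per-token ENM flexibility values.

# Usage sketch, assuming the class_config fields named in the constructor above.
from types import SimpleNamespace
import torch

class_config = SimpleNamespace(
    num_labels=1,                 # one value per residue (assumed)
    adaptor_architecture='conv',
    kernel_size=5,                # assumed
    enm_embed_dim=64,             # assumed
    num_layers=2,                 # assumed
    add_pearson_loss=False,
    add_sse_loss=False,
)

model, tokenizer = ESM_classification_model(half_precision=False,
                                             class_config=class_config,
                                             lora_config=None)

seqs = ["MKTAYIAKQR"]
batch = tokenizer(seqs, return_tensors="pt", padding=True)
# One ENM value per token, padded to the tokenized length (assumed input format).
enm_vals = torch.zeros_like(batch["input_ids"], dtype=torch.float)

with torch.no_grad():
    out = model(enm_vals=enm_vals,
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"])
print(out.logits.shape)  # (batch, tokens, num_labels)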
T5_encoder_per_token.py
ADDED
@@ -0,0 +1,172 @@
import numpy as np
import torch
import torch.nn as nn
import copy
import re
from transformers import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
    ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)
from utils.lora_utils import LoRAConfig, modify_with_lora


class T5EncoderForTokenClassification(T5PreTrainedModel):

    def __init__(self, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)
        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size, class_config.num_labels,
                class_config.enm_embed_dim, class_config.enm_att_heads
            )
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size, class_config.num_labels,
                class_config.kernel_size, class_config.enm_embed_dim, class_config.num_layers
            )
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        enm_vals=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # import pdb; pdb.set_trace()
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # TODO: check the enm_vals are padded properly and check that the sequence limit (in the transformer) is indeed 512
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output  # no loss is computed in this head

        return TokenClassifierOutput(
            # loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def PT5_classification_model(half_precision, class_config):
    # Load ProtT5 and tokenizer
    # It is possible to load the half-precision model (thanks to @pawel-rezo for pointing that out)
    if not half_precision:
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=True)
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=True)
    elif half_precision and torch.cuda.is_available():
        tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, local_files_only=True)
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16, local_files_only=True).to(torch.device('cuda'))
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Create a new classifier model with ProtT5 dimensions
    class_model = T5EncoderForTokenClassification(model.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    # Re-point `model` at the classification model and drop the extra reference
    model = class_model
    del class_model

    # Print number of trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_Classifier\nTrainable Parameter: " + str(params))

    # Load the LoRA modification config
    config = LoRAConfig('configs/lora_config.yaml')

    # Add LoRA layers
    model = modify_with_lora(model, config)

    # Freeze Embeddings and Encoder (except LoRA)
    for (param_name, param) in model.shared.named_parameters():
        param.requires_grad = False
    for (param_name, param) in model.encoder.named_parameters():
        param.requires_grad = False

    for (param_name, param) in model.named_parameters():
        if re.fullmatch(config.trainable_param_names, param_name):
            param.requires_grad = True

    # Print trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_LoRA_Classifier\nTrainable Parameter: " + str(params) + "\n")

    return model, tokenizer
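
A similar sketch for the ProtT5 variant, assuming the Rostlab checkpoints and configs/lora_config.yaml are available locally (the loader above uses local_files_only=True). ProtT5 tokenizers expect space-separated residues with rare amino acids mapped to X; the class_config fields are again placeholders inferred from the constructor, not the project's real configuration.

# Usage sketch under the assumptions stated above.
import re
import torch
from types import SimpleNamespace

class_config = SimpleNamespace(
    num_labels=1,
    dropout_rate=0.1,             # assumed
    adaptor_architecture='conv',
    kernel_size=5,                # assumed
    enm_embed_dim=64,             # assumed
    num_layers=2,                 # assumed
    add_pearson_loss=False,
    add_sse_loss=False,
)

model, tokenizer = PT5_classification_model(half_precision=False, class_config=class_config)

# ProtT5 expects space-separated residues with rare amino acids mapped to X.
seq = "MKTAYIAKQR"
spaced = " ".join(re.sub(r"[UZOB]", "X", seq))
batch = tokenizer([spaced], return_tensors="pt", padding=True)
enm_vals = torch.zeros_like(batch["input_ids"], dtype=torch.float)  # placeholder ENM values

with torch.no_grad():
    out = model(enm_vals=enm_vals,
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"])
print(out.logits.shape)  # (batch, tokens, num_labels)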
__pycache__/T5_encoder_per_token.cpython-313.pyc
ADDED
Binary file (10 kB).

__pycache__/T5_encoder_per_token.cpython-39.pyc
ADDED
Binary file (5.93 kB).

__pycache__/enm_adaptor_heads.cpython-313.pyc
ADDED
Binary file (6.14 kB).

__pycache__/enm_adaptor_heads.cpython-39.pyc
ADDED
Binary file (3.57 kB).
enm_adaptor_heads.py
ADDED
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ENMAdaptedAttentionClassifier(nn.Module):
    def __init__(self, seq_embedding_dim, out_dim, enm_embed_dim, num_att_heads):
        super(ENMAdaptedAttentionClassifier, self).__init__()
        self.embedding = nn.Linear(1, enm_embed_dim)
        self.enm_attention = nn.MultiheadAttention(enm_embed_dim, num_att_heads)
        self.layer_norm = nn.LayerNorm(enm_embed_dim)
        self.enm_adaptor = nn.Linear(enm_embed_dim, seq_embedding_dim)
        self.adapted_classifier = nn.Linear(2 * seq_embedding_dim, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        # attention_mask is accepted to match how the models call the heads; it is unused here
        enm_input = enm_input.transpose(0, 1)  # Transpose to shape (N, B) for MultiheadAttention
        enm_input = enm_input.unsqueeze(-1)  # Add a dimension for the embedding
        enm_input_embedded = self.embedding(enm_input)
        enm_att, _ = self.enm_attention(enm_input_embedded, enm_input_embedded, enm_input_embedded)
        enm_att = enm_att.transpose(0, 1)  # Transpose back to shape (B, N, E)
        enm_att = self.layer_norm(enm_att + enm_input.transpose(0, 1))
        enm_embedding = self.enm_adaptor(enm_att)
        # import pdb; pdb.set_trace()
        combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
        logits = self.adapted_classifier(combined_embedding)
        return logits


class ENMAdaptedConvClassifier(nn.Module):
    def __init__(self, seq_embedding_dim, out_dim, kernel_size, enm_embedding_dim, num_layers):
        super(ENMAdaptedConvClassifier, self).__init__()
        layers = []
        self.conv1 = nn.Conv1d(1, enm_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        layers.append(self.conv1)
        layers.append(nn.ReLU())
        for i in range(num_layers - 1):
            layers.append(nn.Conv1d(enm_embedding_dim, enm_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2))
            layers.append(nn.ReLU())
        self.conv_net = nn.Sequential(*layers)
        self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        enm_input = torch.nan_to_num(enm_input, nan=0.0)
        enm_input = enm_input.unsqueeze(1)
        enm_input = enm_input.to(seq_embedding.device)
        conv_out = self.conv_net(enm_input)
        enm_embedding = conv_out.transpose(1, 2)

        if attention_mask is not None:
            # Use attention_mask to ignore padded elements
            mask = attention_mask.unsqueeze(-1).float()
            enm_embedding = enm_embedding * mask
            # Compute mean over non-padded elements
            enm_embedding = enm_embedding.mean(dim=-1).unsqueeze(-1)
            # enm_embedding = enm_embedding.sum(dim=2) / mask.sum(dim=2).clamp(min=1e-9)
        else:
            raise ValueError('We actually want to provide the mask.')
            enm_embedding = torch.mean(enm_embedding, dim=1)  # unreachable after the raise

        # enm_embedding = enm_embedding.unsqueeze(1).expand(-1, seq_embedding.size(1), -1)
        combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
        logits = self.adapted_classifier(combined_embedding)
        return logits


class ENMAdaptedDirectClassifier(nn.Module):
    def __init__(self, seq_embedding_dim, out_dim):
        super(ENMAdaptedDirectClassifier, self).__init__()
        self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        # attention_mask is accepted to match how the models call the heads; it is unused here
        enm_input = enm_input.unsqueeze(-1)
        combined_embedding = torch.cat((seq_embedding, enm_input), dim=-1)
        logits = self.adapted_classifier(combined_embedding)
        return logits


class ENMNoAdaptorClassifier(nn.Module):
    def __init__(self, seq_embedding_dim, out_dim):
        super(ENMNoAdaptorClassifier, self).__init__()
        self.adapted_classifier = nn.Linear(seq_embedding_dim, out_dim)

    def forward(self, seq_embedding, enm_input, attention_mask=None):
        _ = enm_input  # ignoring enm_input
        logits = self.adapted_classifier(seq_embedding)
        return logits
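
A quick standalone shape check for the adaptor heads (illustrative only): random tensors stand in for a batch of 2 sequences of length 10 with embedding dimension 16 and one ENM value per position. Every head maps (B, L, D) sequence embeddings plus (B, L) ENM values to (B, L, out_dim) logits.

# Shape-check sketch for the heads above.
import torch

B, L, D = 2, 10, 16
seq_embedding = torch.randn(B, L, D)
enm_vals = torch.randn(B, L)
mask = torch.ones(B, L)

conv_head = ENMAdaptedConvClassifier(seq_embedding_dim=D, out_dim=1,
                                     kernel_size=5, enm_embedding_dim=8, num_layers=2)
print(conv_head(seq_embedding, enm_vals, mask).shape)        # torch.Size([2, 10, 1])

direct_head = ENMAdaptedDirectClassifier(seq_embedding_dim=D, out_dim=1)
print(direct_head(seq_embedding, enm_vals, mask).shape)      # torch.Size([2, 10, 1])

attn_head = ENMAdaptedAttentionClassifier(seq_embedding_dim=D, out_dim=1,
                                          enm_embed_dim=8, num_att_heads=2)
print(attn_head(seq_embedding, enm_vals, mask).shape)        # torch.Size([2, 10, 1])

no_adaptor_head = ENMNoAdaptorClassifier(seq_embedding_dim=D, out_dim=1)
print(no_adaptor_head(seq_embedding, enm_vals, mask).shape)  # torch.Size([2, 10, 1])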
weights/.gitkeep
ADDED
File without changes
weights/flexpert_3d_weights.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3cbc7a6bed15e92cc6b5f65b947c3c838e46e5815f7cbd57f54bbc19741558e6
size 4843266070
weights/flexpert_seq_weights.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ca9ffd974154597e372c30faa728e4c61c5811fc98a148af66d31dfe2b5c0061
size 4842603885
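
How the shipped flexpert weights plug into the models above is not documented in this upload. The sketch below assumes each .bin file is a plain torch state_dict for one of the LoRA-adapted models built by the factory functions (which checkpoint pairs with which backbone is likewise an assumption), and loads it non-strictly; it reuses the class_config from the earlier ProtT5 sketch.

# Hedged loading sketch: assumes the .bin files are torch state_dicts compatible
# with the models built above; key layout and model pairing are assumptions.
import torch

model, tokenizer = PT5_classification_model(half_precision=False, class_config=class_config)
state_dict = torch.load("weights/flexpert_seq_weights.bin", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")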