Honzus24 committed
Commit f18fbd7 · verified · 1 Parent(s): 3e749c5

Upload folder using huggingface_hub

ESM_per_token.py ADDED
@@ -0,0 +1,150 @@
+ from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Union, Tuple
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
+ from transformers.modeling_outputs import TokenClassifierOutput
+ import numpy as np
+ import re
+ from models.enm_adaptor_heads import (
+     ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
+     ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
+ )
+ from peft import LoraConfig, inject_adapter_in_model
+
+
+ class EsmForTokenRegression(EsmPreTrainedModel):
+     """ESM-2 encoder with a per-token regression head and an optional ENM adaptor."""
+
+     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+     _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+     def __init__(self, config, class_config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.add_pearson_loss = class_config.add_pearson_loss
+         self.add_sse_loss = class_config.add_sse_loss
+
+         self.esm = EsmModel(config, add_pooling_layer=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         if class_config.adaptor_architecture == 'attention':
+             self.classifier = ENMAdaptedAttentionClassifier(
+                 config.hidden_size,
+                 class_config.num_labels,
+                 class_config.enm_embed_dim,
+                 class_config.enm_att_heads
+             )
+         elif class_config.adaptor_architecture == 'direct':
+             self.classifier = ENMAdaptedDirectClassifier(
+                 config.hidden_size,
+                 class_config.num_labels
+             )
+         elif class_config.adaptor_architecture == 'conv':
+             self.classifier = ENMAdaptedConvClassifier(
+                 config.hidden_size,
+                 class_config.num_labels,
+                 class_config.kernel_size,
+                 class_config.enm_embed_dim,
+                 class_config.num_layers
+             )
+         elif class_config.adaptor_architecture == 'no-adaptor':
+             self.classifier = ENMNoAdaptorClassifier(
+                 config.hidden_size,
+                 class_config.num_labels
+             )
+         else:
+             raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported.')
+
+         self.init_weights()
+
+     def forward(
+         self,
+         enm_vals=None,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, TokenClassifierOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.esm(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         sequence_output = self.dropout(sequence_output)
+
+         logits = self.classifier(sequence_output, enm_vals, attention_mask)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return output
+
+         return TokenClassifierOutput(
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ def ESM_classification_model(half_precision, class_config, lora_config):
+     # Load ESM-2 and its tokenizer
+     if not half_precision:
+         model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
+         tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
+     elif half_precision and torch.cuda.is_available():
+         model = EsmModel.from_pretrained("facebook/esm2_t36_3B_UR50D", torch_dtype=torch.float16).to(torch.device('cuda'))
+         tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
+     else:
+         raise ValueError('Half precision can be run on GPU only.')
+
+     # Create a new classifier model with ESM dimensions
+     class_model = EsmForTokenRegression(model.config, class_config)
+
+     # Set encoder weights to checkpoint weights
+     class_model.esm = model
+
+     # Delete the checkpoint model
+     del model
+
+     # Print number of trainable parameters
+     model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
+     params = sum([np.prod(p.size()) for p in model_parameters])
+     print("ESM_Classifier\nTrainable parameters: " + str(params))
+
+     # LoRA configuration for the encoder
+     esm_lora_peft_config = LoraConfig(
+         r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
+     )
+
+     # Add LoRA layers
+     class_model.esm = inject_adapter_in_model(esm_lora_peft_config, class_model.esm)
+
+     # Freeze the encoder, then unfreeze the LoRA and layer-norm parameters
+     for (param_name, param) in class_model.esm.named_parameters():
+         param.requires_grad = False
+
+     for (param_name, param) in class_model.esm.named_parameters():
+         if re.fullmatch(".*lora.*|.*layer_norm.*", param_name):
+             param.requires_grad = True
+
+     # Print trainable parameters again (frozen encoder + LoRA)
+     model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
+     params = sum([np.prod(p.size()) for p in model_parameters])
+     print("ESM_LoRA_Classifier\nTrainable parameters: " + str(params) + "\n")
+
+     return class_model, tokenizer
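
A minimal usage sketch for the file above, assuming its module layout resolves (the SimpleNamespace container, the field values, and the import paths are illustrative assumptions; the field names mirror the attributes read by EsmForTokenRegression and ESM_classification_model):

from types import SimpleNamespace
import torch
from ESM_per_token import ESM_classification_model  # assumed import path

class_config = SimpleNamespace(
    num_labels=1,                    # one regression value per residue
    adaptor_architecture='conv',     # 'attention' | 'direct' | 'conv' | 'no-adaptor'
    enm_embed_dim=64,
    enm_att_heads=4,
    kernel_size=5,
    num_layers=2,
    add_pearson_loss=False,
    add_sse_loss=False,
)

model, tokenizer = ESM_classification_model(half_precision=False,
                                             class_config=class_config,
                                             lora_config=None)

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
batch = tokenizer(seq, return_tensors="pt")
enm_vals = torch.zeros(1, batch["input_ids"].shape[1])  # per-token ENM values, padded like input_ids
with torch.no_grad():
    out = model(enm_vals=enm_vals, **batch)
print(out.logits.shape)  # (1, sequence length incl. special tokens, num_labels)

Note that loading "facebook/esm2_t36_3B_UR50D" downloads a 3B-parameter checkpoint, so this sketch is best run on a machine with ample memory.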
T5_encoder_per_token.py ADDED
@@ -0,0 +1,172 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import copy
+ import re
+ from transformers import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
+ from transformers.models.t5.modeling_t5 import T5Stack
+ from transformers.modeling_outputs import TokenClassifierOutput
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+ from models.enm_adaptor_heads import (
+     ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
+     ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
+ )
+ from utils.lora_utils import LoRAConfig, modify_with_lora
+
+
+ class T5EncoderForTokenClassification(T5PreTrainedModel):
+     """ProtT5 encoder with a per-token classification head and an optional ENM adaptor."""
+
+     def __init__(self, config: T5Config, class_config):
+         super().__init__(config)
+         self.num_labels = class_config.num_labels
+         self.config = config
+         self.add_pearson_loss = class_config.add_pearson_loss
+         self.add_sse_loss = class_config.add_sse_loss
+         self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+         encoder_config = copy.deepcopy(config)
+         encoder_config.use_cache = False
+         encoder_config.is_encoder_decoder = False
+         self.encoder = T5Stack(encoder_config, self.shared)
+
+         self.dropout = nn.Dropout(class_config.dropout_rate)
+         if class_config.adaptor_architecture == 'attention':
+             self.classifier = ENMAdaptedAttentionClassifier(
+                 config.hidden_size,
+                 class_config.num_labels,
+                 class_config.enm_embed_dim,
+                 class_config.enm_att_heads
+             )
+         elif class_config.adaptor_architecture == 'direct':
+             self.classifier = ENMAdaptedDirectClassifier(
+                 config.hidden_size,
+                 class_config.num_labels
+             )
+         elif class_config.adaptor_architecture == 'conv':
+             self.classifier = ENMAdaptedConvClassifier(
+                 config.hidden_size,
+                 class_config.num_labels,
+                 class_config.kernel_size,
+                 class_config.enm_embed_dim,
+                 class_config.num_layers
+             )
+         elif class_config.adaptor_architecture == 'no-adaptor':
+             self.classifier = ENMNoAdaptorClassifier(
+                 config.hidden_size,
+                 class_config.num_labels
+             )
+         else:
+             raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     def parallelize(self, device_map=None):
+         self.device_map = (
+             get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+             if device_map is None
+             else device_map
+         )
+         assert_device_map(self.device_map, len(self.encoder.block))
+         self.encoder.parallelize(self.device_map)
+         self.classifier = self.classifier.to(self.encoder.first_device)
+         self.model_parallel = True
+
+     def deparallelize(self):
+         self.encoder.deparallelize()
+         self.encoder = self.encoder.to("cpu")
+         self.model_parallel = False
+         self.device_map = None
+         torch.cuda.empty_cache()
+
+     def get_input_embeddings(self):
+         return self.shared
+
+     def set_input_embeddings(self, new_embeddings):
+         self.shared = new_embeddings
+         self.encoder.set_input_embeddings(new_embeddings)
+
+     def get_encoder(self):
+         return self.encoder
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
+         See base class PreTrainedModel.
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
+
+     def forward(
+         self,
+         enm_vals=None,
+         input_ids=None,
+         attention_mask=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         sequence_output = self.dropout(sequence_output)
+         # TODO: check the enm_vals are padded properly and check that the sequence limit (in the transformer) is indeed 512
+         logits = self.classifier(sequence_output, enm_vals, attention_mask)
+
+         # No loss is computed here; the training loop is expected to compute it from the logits.
+         loss = None
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ def PT5_classification_model(half_precision, class_config):
+     # Load ProtT5 and its tokenizer
+     # (it is also possible to load the half-precision model, thanks to @pawel-rezo for pointing that out)
+     if not half_precision:
+         model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=True)
+         tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=True)
+     elif half_precision and torch.cuda.is_available():
+         tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, local_files_only=True)
+         model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16, local_files_only=True).to(torch.device('cuda'))
+     else:
+         raise ValueError('Half precision can be run on GPU only.')
+
+     # Create a new classifier model with ProtT5 dimensions
+     class_model = T5EncoderForTokenClassification(model.config, class_config)
+
+     # Set encoder and embedding weights to checkpoint weights
+     class_model.shared = model.shared
+     class_model.encoder = model.encoder
+
+     # Continue with the classification model; drop the extra reference
+     model = class_model
+     del class_model
+
+     # Print number of trainable parameters
+     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+     params = sum([np.prod(p.size()) for p in model_parameters])
+     print("ProtT5_Classifier\nTrainable parameters: " + str(params))
+
+     # LoRA configuration for the encoder
+     config = LoRAConfig('configs/lora_config.yaml')
+
+     # Add LoRA layers
+     model = modify_with_lora(model, config)
+
+     # Freeze embeddings and encoder (except LoRA)
+     for (param_name, param) in model.shared.named_parameters():
+         param.requires_grad = False
+     for (param_name, param) in model.encoder.named_parameters():
+         param.requires_grad = False
+
+     for (param_name, param) in model.named_parameters():
+         if re.fullmatch(config.trainable_param_names, param_name):
+             param.requires_grad = True
+
+     # Print trainable parameters again (frozen encoder + LoRA)
+     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+     params = sum([np.prod(p.size()) for p in model_parameters])
+     print("ProtT5_LoRA_Classifier\nTrainable parameters: " + str(params) + "\n")
+
+     return model, tokenizer
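
Since neither forward pass above computes a loss, a masked per-token regression loss has to be computed in the training loop. A minimal sketch, assuming an MSE objective and the usual 0/1 attention mask (the function name and shapes are illustrative, not part of the uploaded files):

import torch
import torch.nn.functional as F

def masked_mse(logits, labels, attention_mask):
    # logits: (B, L, 1), labels: (B, L), attention_mask: (B, L) with 1 on real tokens
    preds = logits.squeeze(-1)
    mask = attention_mask.bool()
    return F.mse_loss(preds[mask], labels[mask])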
__pycache__/T5_encoder_per_token.cpython-313.pyc ADDED
Binary file (10 kB).
 
__pycache__/T5_encoder_per_token.cpython-39.pyc ADDED
Binary file (5.93 kB).
 
__pycache__/enm_adaptor_heads.cpython-313.pyc ADDED
Binary file (6.14 kB).
 
__pycache__/enm_adaptor_heads.cpython-39.pyc ADDED
Binary file (3.57 kB).
 
enm_adaptor_heads.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ENMAdaptedAttentionClassifier(nn.Module):
+     """Embeds the per-residue ENM values, runs self-attention over them and
+     concatenates the result with the sequence embedding before classification."""
+
+     def __init__(self, seq_embedding_dim, out_dim, enm_embed_dim, num_att_heads):
+         super(ENMAdaptedAttentionClassifier, self).__init__()
+         self.embedding = nn.Linear(1, enm_embed_dim)
+         self.enm_attention = nn.MultiheadAttention(enm_embed_dim, num_att_heads)
+         self.layer_norm = nn.LayerNorm(enm_embed_dim)
+         self.enm_adaptor = nn.Linear(enm_embed_dim, seq_embedding_dim)
+         self.adapted_classifier = nn.Linear(2 * seq_embedding_dim, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         enm_input = enm_input.transpose(0, 1)  # (B, N) -> (N, B) for MultiheadAttention
+         enm_input = enm_input.unsqueeze(-1)    # (N, B, 1): one scalar ENM value per token
+         enm_input_embedded = self.embedding(enm_input)
+         enm_att, _ = self.enm_attention(enm_input_embedded, enm_input_embedded, enm_input_embedded)
+         enm_att = enm_att.transpose(0, 1)      # back to (B, N, E)
+         enm_att = self.layer_norm(enm_att + enm_input.transpose(0, 1))
+         enm_embedding = self.enm_adaptor(enm_att)
+         combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
+         logits = self.adapted_classifier(combined_embedding)
+         return logits
+
+
+ class ENMAdaptedConvClassifier(nn.Module):
+     """Runs a small 1D CNN over the ENM values and appends the pooled result
+     to the sequence embedding before classification."""
+
+     def __init__(self, seq_embedding_dim, out_dim, kernel_size, enm_embedding_dim, num_layers):
+         super(ENMAdaptedConvClassifier, self).__init__()
+         layers = []
+         self.conv1 = nn.Conv1d(1, enm_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
+         layers.append(self.conv1)
+         layers.append(nn.ReLU())
+         for i in range(num_layers - 1):
+             layers.append(nn.Conv1d(enm_embedding_dim, enm_embedding_dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2))
+             layers.append(nn.ReLU())
+         self.conv_net = nn.Sequential(*layers)
+         self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         enm_input = torch.nan_to_num(enm_input, nan=0.0)
+         enm_input = enm_input.unsqueeze(1)                 # (B, 1, N) for Conv1d
+         enm_input = enm_input.to(seq_embedding.device)
+         conv_out = self.conv_net(enm_input)                # (B, E, N)
+         enm_embedding = conv_out.transpose(1, 2)           # (B, N, E)
+
+         if attention_mask is not None:
+             # Use attention_mask to zero out padded positions
+             mask = attention_mask.unsqueeze(-1).float()
+             enm_embedding = enm_embedding * mask
+             # Collapse the channel dimension to one scalar per token
+             enm_embedding = enm_embedding.mean(dim=-1).unsqueeze(-1)
+         else:
+             raise ValueError('We actually want to provide the mask.')
+
+         combined_embedding = torch.cat((seq_embedding, enm_embedding), dim=-1)
+         logits = self.adapted_classifier(combined_embedding)
+         return logits
+
+
+ class ENMAdaptedDirectClassifier(nn.Module):
+     """Appends the raw ENM value of each token to its sequence embedding."""
+
+     def __init__(self, seq_embedding_dim, out_dim):
+         super(ENMAdaptedDirectClassifier, self).__init__()
+         self.adapted_classifier = nn.Linear(seq_embedding_dim + 1, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         enm_input = enm_input.unsqueeze(-1)
+         combined_embedding = torch.cat((seq_embedding, enm_input), dim=-1)
+         logits = self.adapted_classifier(combined_embedding)
+         return logits
+
+
+ class ENMNoAdaptorClassifier(nn.Module):
+     """Baseline head that ignores the ENM values entirely."""
+
+     def __init__(self, seq_embedding_dim, out_dim):
+         super(ENMNoAdaptorClassifier, self).__init__()
+         self.adapted_classifier = nn.Linear(seq_embedding_dim, out_dim)
+
+     def forward(self, seq_embedding, enm_input, attention_mask=None):
+         _ = enm_input  # the ENM input is ignored
+         logits = self.adapted_classifier(seq_embedding)
+         return logits
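
A small shape-check sketch for the convolutional head, assuming the file is importable as enm_adaptor_heads; the dimensions are arbitrary and chosen only for illustration:

import torch
from enm_adaptor_heads import ENMAdaptedConvClassifier

head = ENMAdaptedConvClassifier(seq_embedding_dim=16, out_dim=1,
                                kernel_size=3, enm_embedding_dim=8, num_layers=2)
seq_emb = torch.randn(2, 10, 16)   # (batch, tokens, hidden)
enm_vals = torch.randn(2, 10)      # one ENM value per token
mask = torch.ones(2, 10)           # attention mask
logits = head(seq_emb, enm_vals, mask)
print(logits.shape)                # torch.Size([2, 10, 1])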
weights/.gitkeep ADDED
File without changes
weights/flexpert_3d_weights.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3cbc7a6bed15e92cc6b5f65b947c3c838e46e5815f7cbd57f54bbc19741558e6
+ size 4843266070
weights/flexpert_seq_weights.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca9ffd974154597e372c30faa728e4c61c5811fc98a148af66d31dfe2b5c0061
+ size 4842603885
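
Assuming the two Git LFS .bin files are plain PyTorch state dicts for models built with the code above (an assumption; the commit does not state their format), loading would look roughly like the sketch below, where model is one of the models constructed in the earlier examples:

import torch

# Hypothetical loading sketch: path and strict=False are assumptions.
state_dict = torch.load("weights/flexpert_seq_weights.bin", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")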