from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
import torch
import torch.nn as nn
from typing import Optional, Union, Tuple
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from torch.nn import MSELoss
from transformers.modeling_outputs import TokenClassifierOutput
import numpy as np
import re
from utils.lora_utils import LoRAConfig, modify_with_lora
from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier, 
    ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)
from peft import LoraConfig, inject_adapter_in_model

class EsmForTokenRegression(EsmPreTrainedModel):
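    """Per-residue regression model: an ESM-2 encoder followed by an ENM-adaptor
    head selected via ``class_config.adaptor_architecture``
    ('attention', 'direct', 'conv', or 'no-adaptor')."""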
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, class_config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size, 
                class_config.num_labels, 
                class_config.enm_embed_dim, 
                class_config.enm_att_heads
            )
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(
                config.hidden_size, 
                class_config.num_labels
            )
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size, 
                class_config.num_labels, 
                class_config.kernel_size, 
                class_config.enm_embed_dim, 
                class_config.num_layers
            )
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(
                config.hidden_size, 
                class_config.num_labels
            )
        else:
            raise ValueError(
                f"Unsupported adaptor_architecture '{class_config.adaptor_architecture}': "
                "only attention, direct, conv and no-adaptor are supported."
            )

        self.init_weights()

    def forward(
        self,
        enm_vals: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
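        """Encode the sequence with ESM and apply the adaptor head.

        ``enm_vals`` carries the per-residue ENM values consumed by the
        classifier head. ``labels`` is accepted for API compatibility but no
        loss is computed here; the loss (including the optional Pearson / SSE
        terms flagged in ``class_config``) is left to the caller.
        """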

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return TokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

def ESM_classification_model(half_precision, class_config, lora_config):
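    """Assemble an ``EsmForTokenRegression`` on top of facebook/esm2_t36_3B_UR50D,
    inject LoRA adapters into the encoder, and freeze everything except the LoRA
    and layer-norm parameters.

    Returns ``(class_model, tokenizer)``. Note that ``lora_config`` is currently
    unused; the PEFT ``LoraConfig`` below is hard-coded.
    """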
    # Load the ESM-2 encoder and tokenizer (fp16 only when a CUDA device is available)
    checkpoint = "facebook/esm2_t36_3B_UR50D"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if not half_precision:
        model = EsmModel.from_pretrained(checkpoint)
    elif torch.cuda.is_available():
        model = EsmModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device("cuda"))
    else:
        raise ValueError("Half precision requires a CUDA-capable GPU.")
    
    # Build the regression model with the checkpoint's configuration
    class_model = EsmForTokenRegression(model.config, class_config)

    # Swap in the pretrained encoder (replaces the randomly initialised one)
    class_model.esm = model

    # Drop the now-redundant reference to the checkpoint model
    del model
    
    # Report the number of trainable parameters before LoRA injection and freezing
    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    print(f"ESM_Classifier\nTrainable parameters: {params}")
    
    # Configure LoRA adapters for the attention projections and dense layers
    esm_lora_peft_config = LoraConfig(
        r=4, lora_alpha=1, bias="all", target_modules=["query", "key", "value", "dense"]
    )

    # Inject the LoRA layers into the encoder
    class_model.esm = inject_adapter_in_model(esm_lora_peft_config, class_model.esm)
    
    # Freeze the encoder; re-enable gradients only for LoRA and layer-norm parameters
    for param_name, param in class_model.esm.named_parameters():
        param.requires_grad = bool(re.fullmatch(r".*lora.*|.*layer_norm.*", param_name))

    # Report the number of trainable parameters after LoRA injection and freezing
    model_parameters = filter(lambda p: p.requires_grad, class_model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    print(f"ESM_LoRA_Classifier\nTrainable parameters: {params}\n")
    
    return class_model, tokenizer
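

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# The class_config field values below are assumptions chosen just to exercise
# the API, and the expected shape of enm_vals depends on the adaptor head.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class_config = SimpleNamespace(
        num_labels=1,                   # one regression target per residue (assumed)
        add_pearson_loss=False,
        add_sse_loss=False,
        adaptor_architecture="direct",  # simplest head; no extra hyperparameters needed
        enm_embed_dim=64,
        enm_att_heads=4,
        kernel_size=3,
        num_layers=2,
    )

    # Note: this downloads the 3B-parameter ESM-2 checkpoint.
    model, tokenizer = ESM_classification_model(
        half_precision=False, class_config=class_config, lora_config=None
    )
    model.eval()

    seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    batch = tokenizer(seq, return_tensors="pt")
    # Dummy per-residue ENM values; shape (batch, tokenized length) is an assumption.
    enm_vals = torch.rand(1, batch["input_ids"].shape[1])

    with torch.no_grad():
        out = model(enm_vals=enm_vals, **batch)
    print(out.logits.shape)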