SkywalkerLu committed on
Commit
bd7b816
·
verified ·
1 Parent(s): c3d85ec

Update modeling_transhla2.py

Browse files
Files changed (1) hide show
  1. modeling_transhla2.py +45 -45
modeling_transhla2.py CHANGED
@@ -1,45 +1,45 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import PreTrainedModel, PretrainedConfig
4
-
5
- from peft import LoraConfig, get_peft_model, TaskType
6
- from transformers import EsmModel
7
-
8
class TransHLA2Config(PretrainedConfig):
    """Configuration object for the TransHLA2 model.

    Attributes:
        d_model: Feature width read by ``TransHLA2`` as the input size of the
            first linear layer of its task head. Defaults to 480.
    """

    model_type = "transhla2"

    def __init__(self, d_model: int = 480, **kwargs):
        """Forward ``kwargs`` to ``PretrainedConfig`` and record ``d_model``."""
        super().__init__(**kwargs)
        # Additional custom parameters can be added here.
        self.d_model = d_model
14
-
15
class TransHLA2(PreTrainedModel):
    """LoRA-adapted ESM encoder followed by a small two-logit classifier.

    The encoder is loaded from ``facebook/esm2_t12_35M_UR50D`` and wrapped
    with PEFT LoRA adapters. Its last hidden state is averaged over
    dimension 1, passed through ``fc_task`` and then through ``classifier``.
    """

    config_class = TransHLA2Config

    def __init__(self, config):
        """Build the encoder, its LoRA wrapper and the classification head.

        Args:
            config: A ``TransHLA2Config``; only ``config.d_model`` is read
                here, as the input width of the task head.
        """
        super().__init__(config)
        self.model_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.tokenizer_name_or_path = "facebook/esm2_t12_35M_UR50D"
        self.peft_config = LoraConfig(
            target_modules=['query', 'out_proj', 'value', 'key', 'dense', 'regression'],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        # Bind a local as well as the attribute: the layer definitions below
        # use the bare name ``d_model``, which used to raise NameError
        # because only ``self.d_model`` had been assigned.
        d_model = self.d_model = config.d_model
        # NOTE(review): d_model is presumably expected to equal the encoder's
        # hidden size; nothing here checks that - confirm.
        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        # ``lora_esm`` wraps the same encoder instance stored in ``self.esm``.
        self.lora_esm = get_peft_model(self.esm, self.peft_config)
        self.fc_task = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.BatchNorm1d(d_model // 4),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(d_model // 4, 32),
            nn.BatchNorm1d(32),
        )
        self.classifier = nn.Linear(32, 2)

    def forward(self, x_in):
        """Run the encoder and the classification head.

        Args:
            x_in: Input passed unchanged to the LoRA-wrapped encoder
                (presumably a batch of token ids - confirm against callers).

        Returns:
            A tuple ``(output, last_hidden_state)``: the classifier output
            and the encoder's last hidden state.
        """
        lora_outputs = self.lora_esm(x_in)
        last_hidden_state = lora_outputs.last_hidden_state
        # No attention mask is passed, so every position along dim 1
        # contributes equally to this mean.
        out_linear = last_hidden_state.mean(dim=1)
        H = self.fc_task(out_linear)
        output = self.classifier(H)
        return output, last_hidden_state
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+
5
+ from peft import LoraConfig, get_peft_model, TaskType
6
+ from transformers import EsmModel
7
+
8
class TransHLA2Config(PretrainedConfig):
    """Configuration object for the TransHLA2 model.

    Attributes:
        d_model: Feature width read by ``TransHLA2`` as the input size of the
            first linear layer of its task head. Defaults to 480.
    """

    model_type = "transhla2"

    def __init__(self, d_model: int = 480, **kwargs):
        """Forward ``kwargs`` to ``PretrainedConfig`` and record ``d_model``."""
        super().__init__(**kwargs)
        # Additional custom parameters can be added here.
        self.d_model = d_model
14
+
15
class TransHLA2(PreTrainedModel):
    """LoRA-adapted ESM encoder followed by a small two-logit classifier.

    The encoder is loaded from ``facebook/esm2_t12_35M_UR50D`` and wrapped
    with PEFT LoRA adapters. Its last hidden state is averaged over
    dimension 1, passed through ``fc_task`` and then through ``classifier``.
    """

    config_class = TransHLA2Config

    def __init__(self, config):
        """Build the encoder, its LoRA wrapper and the classification head.

        Args:
            config: A ``TransHLA2Config``; only ``config.d_model`` is read
                here, as the input width of the task head.
        """
        super().__init__(config)

        checkpoint = "facebook/esm2_t12_35M_UR50D"
        self.model_name_or_path = checkpoint
        self.tokenizer_name_or_path = checkpoint

        self.peft_config = LoraConfig(
            target_modules=["query", "out_proj", "value", "key", "dense", "regression"],
            task_type=TaskType.FEATURE_EXTRACTION,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        )

        # Head widths: input -> input // 4 -> 32.
        width_in = config.d_model
        width_mid = width_in // 4
        width_out = 32

        self.esm = EsmModel.from_pretrained(self.model_name_or_path)
        # ``lora_esm`` wraps the same encoder instance stored in ``self.esm``.
        self.lora_esm = get_peft_model(self.esm, self.peft_config)

        head_layers = [
            nn.Linear(width_in, width_mid),
            nn.BatchNorm1d(width_mid),
            nn.Dropout(0.2),
            nn.SiLU(),
            nn.Linear(width_mid, width_out),
            nn.BatchNorm1d(width_out),
        ]
        self.fc_task = nn.Sequential(*head_layers)
        self.classifier = nn.Linear(width_out, 2)

    def forward(self, x_in):
        """Run the encoder and the classification head.

        Args:
            x_in: Input passed unchanged to the LoRA-wrapped encoder
                (presumably a batch of token ids - confirm against callers).

        Returns:
            A tuple ``(output, last_hidden_state)``: the classifier output
            and the encoder's last hidden state.
        """
        token_states = self.lora_esm(x_in).last_hidden_state
        # No attention mask is passed, so every position along dim 1
        # contributes equally to this mean.
        pooled = token_states.mean(dim=1)
        logits = self.classifier(self.fc_task(pooled))
        return logits, token_states