gabrielbianchin committed on
Commit cf9f7d5 · 1 Parent(s): 0a36690
README.md ADDED
@@ -0,0 +1,25 @@
# ESM2 T36

```python
from transformers import AutoTokenizer, AutoModel
import torch

# ESM2-t36-3B is the backbone that produces the 2560-d sequence
# embeddings consumed by ProtEnrich.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
encoder = AutoModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
protenrich = AutoModel.from_pretrained("SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True)

seqs = ["MKTFFVLLL"]
seqs = [" ".join(i) for i in seqs]
inputs = tokenizer(seqs, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = encoder(**inputs)
    # Mean-pool the per-residue states, dropping the BOS/EOS special tokens.
    pooled = outputs.last_hidden_state[0, 1:-1].mean(dim=0)
    enriched = protenrich(pooled)

print("H enrich:", enriched.h_enrich)
print("H anchor:", enriched.h_anchor)
print("H algn:", enriched.h_algn)
print("Structure:", enriched.struct)
print("Dynamics:", enriched.dyn)
```
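The config below also maps `AutoModelForSequenceClassification` to a classification head on top of the enriched embedding. A minimal sketch of calling that head, reusing `pooled` from the snippet above; note the published weights are for the base model, so the classifier layer is freshly initialized unless you load a checkpoint fine-tuned for your task:

```python
from transformers import AutoModelForSequenceClassification

# Hypothetical usage of the classification head exposed via auto_map.
clf = AutoModelForSequenceClassification.from_pretrained(
    "SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True
)
with torch.no_grad():
    out = clf(seq=pooled)  # `pooled` is the 2560-d embedding from above
print(out.logits)          # shape (num_labels,) for a single pooled embedding
```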
config.json ADDED
@@ -0,0 +1,19 @@
```json
{
  "alpha_max": 0.3,
  "auto_map": {
    "AutoConfig": "configuration_protenrich.ProtEnrichConfig",
    "AutoModel": "modeling_protenrich.ProtEnrichModel",
    "AutoModelForSequenceClassification": "modeling_protenrich.ProtEnrichForSequenceClassification"
  },
  "architectures": [
    "ProtEnrichModel"
  ],
  "dtype": "float32",
  "dyn_dim": 20,
  "embed_dim": 1024,
  "model_type": "protenrich",
  "project_dim": 256,
  "seq_dim": 2560,
  "struct_dim": 1024,
  "transformers_version": "4.57.3"
}
```
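Here `seq_dim` of 2560 matches the hidden size of `esm2_t36_3B_UR50D`, and `embed_dim` is the size of the enriched embedding. A quick sanity check on the exported config (a sketch; assumes access to the Hub):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True)
assert cfg.seq_dim == 2560                            # ESM2-t36-3B hidden size
print(cfg.embed_dim, cfg.project_dim, cfg.alpha_max)  # 1024 256 0.3
```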
configuration_protenrich.py ADDED
@@ -0,0 +1,23 @@
```python
from transformers import PretrainedConfig


class ProtEnrichConfig(PretrainedConfig):
    model_type = "protenrich"

    def __init__(
        self,
        seq_dim: int = 2560,
        struct_dim: int = 1024,
        dyn_dim: int = 20,
        embed_dim: int = 1024,
        project_dim: int = 256,
        alpha_max: float = 0.3,
        num_labels: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seq_dim = seq_dim
        self.struct_dim = struct_dim
        self.dyn_dim = dyn_dim
        self.embed_dim = embed_dim
        self.project_dim = project_dim
        self.alpha_max = alpha_max
        self.num_labels = num_labels
```
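Since the defaults mirror config.json, the config can also be built directly, e.g. to size the classification head for a different label set (a sketch; `num_labels=5` is an arbitrary example and the import assumes the file is on your path):

```python
from configuration_protenrich import ProtEnrichConfig

cfg = ProtEnrichConfig(num_labels=5)  # everything else keeps the defaults above
print(cfg.num_labels, cfg.embed_dim)  # 5 1024
```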
model.safetensors ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:a43df16fca77cc8426c8d9ecc33b6c564910c4e987e4dbeb7ebbbd8d7a002514
size 72656740
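```

As a consistency check, 72,656,740 bytes at 4 bytes per float32 parameter is roughly 18.2M parameters, which lines up with the stack of MLP encoders, decoders, and projectors defined in modeling_protenrich.py below.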
modeling_protenrich.py ADDED
@@ -0,0 +1,122 @@
```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import ModelOutput

from .configuration_protenrich import ProtEnrichConfig


@dataclass
class ProtEnrichModelOutput(ModelOutput):
    h_enrich: Optional[torch.FloatTensor] = None
    h_anchor: Optional[torch.FloatTensor] = None
    h_algn: Optional[torch.FloatTensor] = None
    struct: Optional[torch.FloatTensor] = None
    dyn: Optional[torch.FloatTensor] = None


class MLPEncoder(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
        super().__init__()
        layers = []
        d = in_dim
        for _ in range(n_layers - 1):
            layers += [
                nn.Linear(d, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            d = hidden_dim
        layers.append(nn.Linear(d, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class ProtEnrichModel(PreTrainedModel):
    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)

        self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
        self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)
        self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
        self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)

        # The structure and dynamics encoders are kept frozen; only the
        # sequence branches receive gradients.
        for p in self.struct_encoder.parameters():
            p.requires_grad = False
        for p in self.dyn_encoder.parameters():
            p.requires_grad = False

        self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)

        self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
        self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
        self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)

        self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
        self.alpha_max = config.alpha_max

        self.norm_anchor = nn.LayerNorm(config.embed_dim)
        self.norm_algn = nn.LayerNorm(config.embed_dim)

        self.post_init()

    def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
        h_anchor = self.norm_anchor(self.seq_anchor(seq))
        h_algn = self.norm_algn(self.seq_algn(seq))

        struct = self.struct_decoder(h_algn)
        dyn = self.dyn_decoder(h_algn)

        # Gated residual: the alignment branch contributes at most alpha_max.
        alpha = torch.sigmoid(self.alpha_logit) * self.alpha_max
        h_enrich = h_anchor + alpha * h_algn

        return ProtEnrichModelOutput(
            h_enrich=h_enrich,
            h_anchor=h_anchor,
            h_algn=h_algn,
            struct=struct,
            dyn=dyn,
        )


class ProtEnrichForSequenceClassification(PreTrainedModel):
    config_class = ProtEnrichConfig

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)

        self.num_labels = config.num_labels

        self.protenrich = ProtEnrichModel(config)
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)

        self.post_init()

    def forward(self, seq: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None):
        outputs = self.protenrich(seq=seq, return_dict=return_dict)
        pooled = outputs.h_enrich

        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=(pooled,),  # tuple, as SequenceClassifierOutput expects
        )
```
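
The enriched embedding is a gated residual, `h_enrich = h_anchor + sigmoid(alpha_logit) * alpha_max * h_algn`, so the alignment branch can contribute at most `alpha_max` of its normalized output. A minimal shape check with random inputs (a sketch; loading via `AutoModel` pulls the custom code and weights from the Hub):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True).eval()

x = torch.randn(4, model.config.seq_dim)  # a batch of pooled ESM2 embeddings
with torch.no_grad():
    out = model(seq=x)

print(out.h_enrich.shape)               # torch.Size([4, 1024]) == (batch, embed_dim)
print(out.struct.shape, out.dyn.shape)  # torch.Size([4, 1024]) torch.Size([4, 20])
```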