Update modeling_protst.py
Browse files- modeling_protst.py +5 -1
modeling_protst.py
CHANGED
|
@@ -55,6 +55,8 @@ class BertForPubMed(BertPreTrainedModel):
|
|
| 55 |
self.text_mlp = ProtSTHead(config)
|
| 56 |
self.word_mlp = ProtSTHead(config)
|
| 57 |
|
|
|
|
|
|
|
| 58 |
def forward(
|
| 59 |
self,
|
| 60 |
input_ids: Optional[torch.Tensor] = None,
|
|
@@ -111,7 +113,7 @@ class EsmForProteinRepresentation(EsmPreTrainedModel):
|
|
| 111 |
self.protein_mlp = ProtSTHead(config)
|
| 112 |
self.residue_mlp = ProtSTHead(config)
|
| 113 |
|
| 114 |
-
self.
|
| 115 |
|
| 116 |
def forward(
|
| 117 |
self,
|
|
@@ -163,6 +165,8 @@ class EsmForProteinPropertyPrediction(EsmPreTrainedModel):
|
|
| 163 |
self.model = EsmForProteinRepresentation(config)
|
| 164 |
self.classifier = ProtSTHead(config, out_dim=config.num_labels)
|
| 165 |
|
|
|
|
|
|
|
| 166 |
def forward(
|
| 167 |
self,
|
| 168 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
| 55 |
self.text_mlp = ProtSTHead(config)
|
| 56 |
self.word_mlp = ProtSTHead(config)
|
| 57 |
|
| 58 |
+
self.post_init() # NOTE
|
| 59 |
+
|
| 60 |
def forward(
|
| 61 |
self,
|
| 62 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 113 |
self.protein_mlp = ProtSTHead(config)
|
| 114 |
self.residue_mlp = ProtSTHead(config)
|
| 115 |
|
| 116 |
+
self.post_init() # NOTE
|
| 117 |
|
| 118 |
def forward(
|
| 119 |
self,
|
|
|
|
| 165 |
self.model = EsmForProteinRepresentation(config)
|
| 166 |
self.classifier = ProtSTHead(config, out_dim=config.num_labels)
|
| 167 |
|
| 168 |
+
self.post_init() # NOTE
|
| 169 |
+
|
| 170 |
def forward(
|
| 171 |
self,
|
| 172 |
input_ids: Optional[torch.LongTensor] = None,
|