Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +1 -11
modeling_esm_plusplus.py
CHANGED
|
@@ -567,14 +567,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
| 567 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 568 |
return (x * attention_mask).max(dim=1).values
|
| 569 |
|
| 570 |
-
def min_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 571 |
-
"""Apply min pooling to sequence outputs."""
|
| 572 |
-
if attention_mask is None:
|
| 573 |
-
return x.min(dim=1).values
|
| 574 |
-
else:
|
| 575 |
-
attention_mask = attention_mask.unsqueeze(-1)
|
| 576 |
-
return (x * attention_mask).min(dim=1).values
|
| 577 |
-
|
| 578 |
def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 579 |
"""Apply cls pooling to sequence outputs."""
|
| 580 |
return x[:, 0, :]
|
|
@@ -633,13 +625,11 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
| 633 |
|
| 634 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 635 |
if full_embeddings:
|
| 636 |
-
return residue_embeddings
|
| 637 |
elif pooling_type == 'mean':
|
| 638 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
| 639 |
elif pooling_type == 'max':
|
| 640 |
return self.max_pooling(residue_embeddings, attention_mask)
|
| 641 |
-
elif pooling_type == 'min':
|
| 642 |
-
return self.min_pooling(residue_embeddings, attention_mask)
|
| 643 |
elif pooling_type == 'cls':
|
| 644 |
return self.cls_pooling(residue_embeddings, attention_mask)
|
| 645 |
else:
|
|
|
|
| 567 |
attention_mask = attention_mask.unsqueeze(-1)
|
| 568 |
return (x * attention_mask).max(dim=1).values
|
| 569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 571 |
"""Apply cls pooling to sequence outputs."""
|
| 572 |
return x[:, 0, :]
|
|
|
|
| 625 |
|
| 626 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 627 |
if full_embeddings:
|
| 628 |
+
return residue_embeddings, attention_mask
|
| 629 |
elif pooling_type == 'mean':
|
| 630 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
| 631 |
elif pooling_type == 'max':
|
| 632 |
return self.max_pooling(residue_embeddings, attention_mask)
|
|
|
|
|
|
|
| 633 |
elif pooling_type == 'cls':
|
| 634 |
return self.cls_pooling(residue_embeddings, attention_mask)
|
| 635 |
else:
|