Add missing likelihood method implementation to adapter class
Browse files- adapter.py +19 -0
adapter.py
CHANGED
|
@@ -798,6 +798,25 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 798 |
else:
|
| 799 |
return embeddings.cpu().numpy()
|
| 800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
def confidence(self, seqs, **kwargs):
|
| 802 |
"""Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
|
| 803 |
# Format input: join VH and VL with '|'
|
|
|
|
| 798 |
else:
|
| 799 |
return embeddings.cpu().numpy()
|
| 800 |
|
| 801 |
+
def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
|
| 802 |
+
"""Likelihood of mutations - returns logits for each amino acid at each position."""
|
| 803 |
+
# Format input: join VH and VL with '|'
|
| 804 |
+
formatted_seqs = []
|
| 805 |
+
for s in seqs:
|
| 806 |
+
if isinstance(s, (list, tuple)):
|
| 807 |
+
formatted_seqs.append('|'.join(s))
|
| 808 |
+
else:
|
| 809 |
+
formatted_seqs.append(s)
|
| 810 |
+
|
| 811 |
+
# Get logits
|
| 812 |
+
if stepwise_masking:
|
| 813 |
+
logits = self._predict_logits_with_step_masking(formatted_seqs)
|
| 814 |
+
else:
|
| 815 |
+
logits = self._predict_logits(formatted_seqs)
|
| 816 |
+
|
| 817 |
+
# Return logits as numpy array
|
| 818 |
+
return logits.cpu().numpy()
|
| 819 |
+
|
| 820 |
def confidence(self, seqs, **kwargs):
|
| 821 |
"""Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
|
| 822 |
# Format input: join VH and VL with '|'
|