Add missing rescoding method implementation to adapter class
Browse files- adapter.py +22 -0
adapter.py
CHANGED
|
@@ -776,6 +776,28 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 776 |
|
| 777 |
return seq_embeddings.cpu().numpy()
|
| 778 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
def confidence(self, seqs, **kwargs):
|
| 780 |
"""Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
|
| 781 |
# Format input: join VH and VL with '|'
|
|
|
|
| 776 |
|
| 777 |
return seq_embeddings.cpu().numpy()
|
| 778 |
|
| 779 |
+
def rescoding(self, seqs, align=False, **kwargs):
|
| 780 |
+
"""Residue specific representations - returns 480-dimensional embeddings for each residue."""
|
| 781 |
+
# Format input: join VH and VL with '|'
|
| 782 |
+
formatted_seqs = []
|
| 783 |
+
for s in seqs:
|
| 784 |
+
if isinstance(s, (list, tuple)):
|
| 785 |
+
formatted_seqs.append('|'.join(s))
|
| 786 |
+
else:
|
| 787 |
+
formatted_seqs.append(s)
|
| 788 |
+
|
| 789 |
+
# Get embeddings using the model
|
| 790 |
+
embeddings = self._encode_sequences(formatted_seqs)
|
| 791 |
+
|
| 792 |
+
# Return residue-level embeddings
|
| 793 |
+
# embeddings shape: [batch_size, seq_len, hidden_size]
|
| 794 |
+
if len(embeddings.shape) == 3:
|
| 795 |
+
# Convert to numpy and return as list of arrays for each sequence
|
| 796 |
+
embeddings_np = embeddings.cpu().numpy()
|
| 797 |
+
return [embeddings_np[i] for i in range(embeddings_np.shape[0])]
|
| 798 |
+
else:
|
| 799 |
+
return embeddings.cpu().numpy()
|
| 800 |
+
|
| 801 |
def confidence(self, seqs, **kwargs):
|
| 802 |
"""Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
|
| 803 |
# Format input: join VH and VL with '|'
|