hemantn commited on
Commit
b1f86f0
·
1 Parent(s): 3d5a035

Add missing rescoding method implementation to adapter class

Browse files
Files changed (1) hide show
  1. adapter.py +22 -0
adapter.py CHANGED
@@ -776,6 +776,28 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
776
 
777
  return seq_embeddings.cpu().numpy()
778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
  def confidence(self, seqs, **kwargs):
780
  """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
781
  # Format input: join VH and VL with '|'
 
776
 
777
  return seq_embeddings.cpu().numpy()
778
 
779
+ def rescoding(self, seqs, align=False, **kwargs):
780
+ """Residue specific representations - returns 480-dimensional embeddings for each residue."""
781
+ # Format input: join VH and VL with '|'
782
+ formatted_seqs = []
783
+ for s in seqs:
784
+ if isinstance(s, (list, tuple)):
785
+ formatted_seqs.append('|'.join(s))
786
+ else:
787
+ formatted_seqs.append(s)
788
+
789
+ # Get embeddings using the model
790
+ embeddings = self._encode_sequences(formatted_seqs)
791
+
792
+ # Return residue-level embeddings
793
+ # embeddings shape: [batch_size, seq_len, hidden_size]
794
+ if len(embeddings.shape) == 3:
795
+ # Convert to numpy and return as list of arrays for each sequence
796
+ embeddings_np = embeddings.cpu().numpy()
797
+ return [embeddings_np[i] for i in range(embeddings_np.shape[0])]
798
+ else:
799
+ return embeddings.cpu().numpy()
800
+
801
  def confidence(self, seqs, **kwargs):
802
  """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
803
  # Format input: join VH and VL with '|'