hemantn commited on
Commit
dec4a70
·
1 Parent(s): b1f86f0

Add missing likelihood method implementation to adapter class

Browse files
Files changed (1) hide show
  1. adapter.py +19 -0
adapter.py CHANGED
@@ -798,6 +798,25 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
798
  else:
799
  return embeddings.cpu().numpy()
800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801
  def confidence(self, seqs, **kwargs):
802
  """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
803
  # Format input: join VH and VL with '|'
 
798
  else:
799
  return embeddings.cpu().numpy()
800
 
801
+ def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
802
+ """Likelihood of mutations - returns logits for each amino acid at each position."""
803
+ # Format input: join VH and VL with '|'
804
+ formatted_seqs = []
805
+ for s in seqs:
806
+ if isinstance(s, (list, tuple)):
807
+ formatted_seqs.append('|'.join(s))
808
+ else:
809
+ formatted_seqs.append(s)
810
+
811
+ # Get logits
812
+ if stepwise_masking:
813
+ logits = self._predict_logits_with_step_masking(formatted_seqs)
814
+ else:
815
+ logits = self._predict_logits(formatted_seqs)
816
+
817
+ # Return logits as numpy array
818
+ return logits.cpu().numpy()
819
+
820
  def confidence(self, seqs, **kwargs):
821
  """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss."""
822
  # Format input: join VH and VL with '|'