hemantn commited on
Commit
0e5736d
·
1 Parent(s): 2b25439

Remove HFAbRestore.restore override to use original AbRestore.restore method

Browse files
Files changed (1) hide show
  1. adapter.py +0 -29
adapter.py CHANGED
@@ -558,35 +558,6 @@ class HFAbRestore(AbRestore):
558
  return output.last_hidden_state
559
  return output
560
  return model_call
561
-
562
- def restore(self, seqs, align=False, **kwargs):
563
- """Restore masked residues in antibody sequences."""
564
- if isinstance(seqs, str):
565
- seqs = [seqs]
566
-
567
- # Use the original AbLang2 restore logic
568
- # Tokenize sequences using original interface
569
- tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
570
-
571
- # Get predictions for amino acids (indices 1-20)
572
- predictions = self.AbLang(tokens)[:,:,1:21]
573
-
574
- # Find predicted tokens and replace mask tokens
575
- predicted_tokens = torch.max(predictions, -1).indices + 1
576
- restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
577
-
578
- # Decode back to sequences using original tokenizer
579
- restored_seqs = self.tokenizer(restored_tokens, mode="decode")
580
-
581
- # Handle paired sequences format
582
- n_seqs = len(seqs)
583
- if n_seqs < len(restored_seqs):
584
- restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
585
- seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
586
-
587
- # Apply final formatting
588
- from extra_utils import res_to_seq
589
- return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
590
 
591
  def add_angle_brackets(seq):
592
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
 
558
  return output.last_hidden_state
559
  return output
560
  return model_call
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
 
562
  def add_angle_brackets(seq):
563
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'