hemantn commited on
Commit
2e88277
·
1 Parent(s): e75c857

Implement working restore method in HFAbRestore class

Browse files
Files changed (1) hide show
  1. adapter.py +34 -0
adapter.py CHANGED
@@ -559,6 +559,40 @@ class HFAbRestore(AbRestore):
559
  return output.last_hidden_state
560
  return output
561
  return model_call
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
  def add_angle_brackets(seq):
564
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
 
559
  return output.last_hidden_state
560
  return output
561
  return model_call
562
+
563
+ def restore(self, seqs, align=False, **kwargs):
564
+ """Restore masked residues in antibody sequences."""
565
+ if isinstance(seqs, str):
566
+ seqs = [seqs]
567
+
568
+ n_seqs = len(seqs)
569
+
570
+ if align:
571
+ # For now, skip alignment as it requires ANARCI
572
+ print("WARNING: Alignment not implemented, skipping...")
573
+ pass
574
+
575
+ # Tokenize sequences using original interface
576
+ tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
577
+
578
+ # Get predictions for amino acids (indices 1-20)
579
+ predictions = self.AbLang(tokens)[:,:,1:21]
580
+
581
+ # Find predicted tokens and replace mask tokens
582
+ predicted_tokens = torch.max(predictions, -1).indices + 1
583
+ restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
584
+
585
+ # Decode back to sequences using original tokenizer
586
+ restored_seqs = self.tokenizer(restored_tokens, mode="decode")
587
+
588
+ # Handle paired sequences format
589
+ if n_seqs < len(restored_seqs):
590
+ restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
591
+ seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
592
+
593
+ # Apply final formatting
594
+ from extra_utils import res_to_seq
595
+ return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
596
 
597
  def add_angle_brackets(seq):
598
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'