hemantn committed on
Commit
eb85d4e
·
1 Parent(s): 55a4139

Revert restore method to use original AbLang2 logic

Browse files
Files changed (1) hide show
  1. adapter.py +5 -39
adapter.py CHANGED
@@ -561,48 +561,14 @@ class HFAbRestore(AbRestore):
561
 
562
  def restore(self, seqs, align=False, **kwargs):
563
  """Restore masked residues in antibody sequences."""
 
 
564
  if isinstance(seqs, str):
565
  seqs = [seqs]
566
 
567
- restored_seqs = []
568
- for seq in seqs:
569
- # Tokenize the sequence using the correct interface
570
- input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)
571
-
572
- # Find masked tokens
573
- mask_token_id = self.tokenizer.mask_token_id
574
- masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
575
-
576
- if len(masked_positions) == 0:
577
- # No masked tokens, return original sequence
578
- restored_seqs.append(seq)
579
- continue
580
-
581
- # Get predictions for masked positions
582
- with torch.no_grad():
583
- output = self._hf_model(input_ids)
584
- if hasattr(output, 'last_hidden_state'):
585
- logits = output.last_hidden_state
586
- else:
587
- logits = output
588
-
589
- # Get predictions for masked positions
590
- masked_logits = logits[0, masked_positions] # [num_masked, vocab_size]
591
- predicted_tokens = torch.argmax(masked_logits, dim=-1)
592
-
593
- # Replace masked tokens with predicted tokens
594
- restored_input_ids = input_ids[0].clone()
595
- restored_input_ids[masked_positions] = predicted_tokens
596
-
597
- # Decode back to sequence using the original tokenizer and remove spaces
598
- restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
599
- # Remove spaces that might be added during decoding
600
- restored_seq = restored_seq.replace(' ', '')
601
- restored_seqs.append(restored_seq)
602
-
603
- # For now, return restored sequences without alignment
604
- # Proper ANARCI-based alignment requires full implementation of alignment functions
605
- return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
606
 
607
  def add_angle_brackets(seq):
608
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
 
561
 
562
  def restore(self, seqs, align=False, **kwargs):
563
  """Restore masked residues in antibody sequences."""
564
+ # Use the original AbLang2 restore logic
565
+ # This should work correctly like it did before
566
  if isinstance(seqs, str):
567
  seqs = [seqs]
568
 
569
+ # Use the original restore logic from the parent class
570
+ # The AbRestore class should have the working implementation
571
+ return super().restore(seqs, align=align, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
 
573
  def add_angle_brackets(seq):
574
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'