hemantn committed
Commit acf4b07 · 1 Parent(s): eb85d4e

Fix restore method to properly handle masked tokens

Files changed (1): adapter.py (+42, -5)
adapter.py CHANGED
@@ -561,14 +561,51 @@ class HFAbRestore(AbRestore):
 
     def restore(self, seqs, align=False, **kwargs):
         """Restore masked residues in antibody sequences."""
-        # Use the original AbLang2 restore logic
-        # This should work correctly like it did before
         if isinstance(seqs, str):
             seqs = [seqs]
 
-        # Use the original restore logic from the parent class
-        # The AbRestore class should have the working implementation
-        return super().restore(seqs, align=align, **kwargs)
+        restored_seqs = []
+        for seq in seqs:
+            # Check if the sequence has masked tokens
+            if '*' not in seq:
+                # No masked tokens, return as-is
+                restored_seqs.append(seq)
+                continue
+
+            # Tokenize the sequence
+            input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)
+
+            # Find masked tokens (assuming '*' is the mask token)
+            mask_token_id = self.tokenizer.mask_token_id
+            masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
+
+            if len(masked_positions) == 0:
+                # No masked tokens found, return the original
+                restored_seqs.append(seq)
+                continue
+
+            # Run the model to get logits
+            with torch.no_grad():
+                output = self._hf_model(input_ids)
+                if hasattr(output, 'last_hidden_state'):
+                    logits = output.last_hidden_state
+                else:
+                    logits = output
+
+            # Get predictions for the masked positions
+            masked_logits = logits[0, masked_positions]
+            predicted_tokens = torch.argmax(masked_logits, dim=-1)
+
+            # Replace masked tokens with predicted tokens
+            restored_input_ids = input_ids[0].clone()
+            restored_input_ids[masked_positions] = predicted_tokens
+
+            # Decode back to a sequence
+            restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
+            restored_seq = restored_seq.replace(' ', '')
+            restored_seqs.append(restored_seq)
+
+        return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
 
 def add_angle_brackets(seq):
     # Assumes input is 'VH|VL' or 'VH|' or '|VL'
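
For reference, a minimal usage sketch of the fixed method. This is illustrative only: the `restorer` variable stands in for an HFAbRestore instance constructed elsewhere in adapter.py, and the example sequence is hypothetical; '*' marks the residues to restore, matching the check in the code above.

# Illustrative sketch; `restorer` is assumed to be an HFAbRestore instance
# with its tokenizer and model already set up (construction not shown here).
masked_seq = "EVQLVESGG*LVQPGGSLRLSCAAS*FTFSSYAMS"  # hypothetical VH fragment

restored = restorer.restore(masked_seq)              # str in -> str out
batch = restorer.restore([masked_seq, masked_seq])   # >1 input -> np.array of str

Design-wise, the new logic fills every masked position from a single forward pass and takes the argmax per position, i.e. a greedy, single-shot restore rather than an iterative refill.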