Fix restore method to properly handle masked tokens
Browse files- adapter.py +42 -5
adapter.py
CHANGED
|
@@ -561,14 +561,51 @@ class HFAbRestore(AbRestore):
|
|
| 561 |
|
| 562 |
def restore(self, seqs, align=False, **kwargs):
|
| 563 |
"""Restore masked residues in antibody sequences."""
|
| 564 |
-
# Use the original AbLang2 restore logic
|
| 565 |
-
# This should work correctly like it did before
|
| 566 |
if isinstance(seqs, str):
|
| 567 |
seqs = [seqs]
|
| 568 |
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
|
| 573 |
def add_angle_brackets(seq):
|
| 574 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
|
|
| 561 |
|
| 562 |
def restore(self, seqs, align=False, **kwargs):
    """Restore masked residues ('*') in antibody sequences.

    Each input sequence is tokenized, every masked position is filled with
    the model's argmax prediction, and the result is decoded back into a
    plain amino-acid string.

    Parameters
    ----------
    seqs : str or sequence of str
        A single sequence, or a list of sequences, possibly containing
        '*' mask characters.
    align : bool
        Accepted for API compatibility; not used by this implementation.
    **kwargs
        Accepted for API compatibility; ignored.

    Returns
    -------
    str or numpy.ndarray
        The restored string when exactly one sequence was given, otherwise
        a numpy array of restored strings. An empty input yields an empty
        array.
    """
    if isinstance(seqs, str):
        seqs = [seqs]

    restored_seqs = []
    for seq in seqs:
        # Fast path: no mask characters, nothing to restore.
        if '*' not in seq:
            restored_seqs.append(seq)
            continue

        # Tokenize as a single-sequence batch.
        input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)

        # Locate masked positions by token id. NOTE(review): assumes '*'
        # maps to the tokenizer's mask token -- confirm against the vocab.
        mask_token_id = self.tokenizer.mask_token_id
        masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]

        if len(masked_positions) == 0:
            # '*' appeared in the text but did not tokenize to a mask
            # token; return the sequence unchanged rather than guessing.
            restored_seqs.append(seq)
            continue

        # Inference only -- no gradients needed.
        with torch.no_grad():
            output = self._hf_model(input_ids)
            # HF-style outputs wrap the tensor; plain callables may
            # return the tensor directly.
            logits = output.last_hidden_state if hasattr(output, 'last_hidden_state') else output

        # Argmax prediction at each masked position.
        predicted_tokens = torch.argmax(logits[0, masked_positions], dim=-1)

        # Splice the predictions into a copy of the token ids so the
        # original input tensor is left untouched.
        restored_input_ids = input_ids[0].clone()
        restored_input_ids[masked_positions] = predicted_tokens

        # Decode and strip the tokenizer's inter-token spaces.
        restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
        restored_seqs.append(restored_seq.replace(' ', ''))

    # Bug fix: the original expression
    #   np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
    # raised IndexError on an empty input list. Keep the established
    # contract (bare string for one sequence, array for several) and
    # return an empty array for empty input.
    if len(restored_seqs) == 1:
        return restored_seqs[0]
    return np.array(restored_seqs)
|
| 609 |
|
| 610 |
def add_angle_brackets(seq):
|
| 611 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|