Implement working restore method in HFAbRestore class
Browse files- adapter.py +34 -0
adapter.py
CHANGED
|
@@ -559,6 +559,40 @@ class HFAbRestore(AbRestore):
|
|
| 559 |
return output.last_hidden_state
|
| 560 |
return output
|
| 561 |
return model_call
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
|
| 563 |
def add_angle_brackets(seq):
|
| 564 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
|
|
| 559 |
return output.last_hidden_state
|
| 560 |
return output
|
| 561 |
return model_call
|
| 562 |
+
|
| 563 |
+
def restore(self, seqs, align=False, **kwargs):
|
| 564 |
+
"""Restore masked residues in antibody sequences."""
|
| 565 |
+
if isinstance(seqs, str):
|
| 566 |
+
seqs = [seqs]
|
| 567 |
+
|
| 568 |
+
n_seqs = len(seqs)
|
| 569 |
+
|
| 570 |
+
if align:
|
| 571 |
+
# For now, skip alignment as it requires ANARCI
|
| 572 |
+
print("WARNING: Alignment not implemented, skipping...")
|
| 573 |
+
pass
|
| 574 |
+
|
| 575 |
+
# Tokenize sequences using original interface
|
| 576 |
+
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
|
| 577 |
+
|
| 578 |
+
# Get predictions for amino acids (indices 1-20)
|
| 579 |
+
predictions = self.AbLang(tokens)[:,:,1:21]
|
| 580 |
+
|
| 581 |
+
# Find predicted tokens and replace mask tokens
|
| 582 |
+
predicted_tokens = torch.max(predictions, -1).indices + 1
|
| 583 |
+
restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
|
| 584 |
+
|
| 585 |
+
# Decode back to sequences using original tokenizer
|
| 586 |
+
restored_seqs = self.tokenizer(restored_tokens, mode="decode")
|
| 587 |
+
|
| 588 |
+
# Handle paired sequences format
|
| 589 |
+
if n_seqs < len(restored_seqs):
|
| 590 |
+
restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
|
| 591 |
+
seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
|
| 592 |
+
|
| 593 |
+
# Apply final formatting
|
| 594 |
+
from extra_utils import res_to_seq
|
| 595 |
+
return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
|
| 596 |
|
| 597 |
def add_angle_brackets(seq):
|
| 598 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|