Remove HFAbRestore.restore override to use original AbRestore.restore method
Browse files- adapter.py +0 -29
adapter.py
CHANGED
|
@@ -558,35 +558,6 @@ class HFAbRestore(AbRestore):
|
|
| 558 |
return output.last_hidden_state
|
| 559 |
return output
|
| 560 |
return model_call
|
| 561 |
-
|
| 562 |
-
def restore(self, seqs, align=False, **kwargs):
|
| 563 |
-
"""Restore masked residues in antibody sequences."""
|
| 564 |
-
if isinstance(seqs, str):
|
| 565 |
-
seqs = [seqs]
|
| 566 |
-
|
| 567 |
-
# Use the original AbLang2 restore logic
|
| 568 |
-
# Tokenize sequences using original interface
|
| 569 |
-
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
|
| 570 |
-
|
| 571 |
-
# Get predictions for amino acids (indices 1-20)
|
| 572 |
-
predictions = self.AbLang(tokens)[:,:,1:21]
|
| 573 |
-
|
| 574 |
-
# Find predicted tokens and replace mask tokens
|
| 575 |
-
predicted_tokens = torch.max(predictions, -1).indices + 1
|
| 576 |
-
restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
|
| 577 |
-
|
| 578 |
-
# Decode back to sequences using original tokenizer
|
| 579 |
-
restored_seqs = self.tokenizer(restored_tokens, mode="decode")
|
| 580 |
-
|
| 581 |
-
# Handle paired sequences format
|
| 582 |
-
n_seqs = len(seqs)
|
| 583 |
-
if n_seqs < len(restored_seqs):
|
| 584 |
-
restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
|
| 585 |
-
seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
|
| 586 |
-
|
| 587 |
-
# Apply final formatting
|
| 588 |
-
from extra_utils import res_to_seq
|
| 589 |
-
return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
|
| 590 |
|
| 591 |
def add_angle_brackets(seq):
|
| 592 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
|
|
| 558 |
return output.last_hidden_state
|
| 559 |
return output
|
| 560 |
return model_call
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
def add_angle_brackets(seq):
|
| 563 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|