Implement proper alignment functionality using ANARCI
Browse files- adapter.py +67 -29
adapter.py
CHANGED
|
@@ -568,46 +568,84 @@ class HFAbRestore(AbRestore):
|
|
| 568 |
n_seqs = len(seqs)
|
| 569 |
|
| 570 |
if align:
|
| 571 |
-
#
|
| 572 |
-
print("
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
predicted_tokens = torch.max(predictions, -1).indices + 1
|
| 588 |
-
print(f"DEBUG: Predicted tokens: {predicted_tokens[0]}")
|
| 589 |
-
|
| 590 |
-
# Find mask token positions
|
| 591 |
-
mask_positions = (tokens == 23).nonzero(as_tuple=True)
|
| 592 |
-
print(f"DEBUG: Mask token positions: {mask_positions}")
|
| 593 |
-
|
| 594 |
restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
# Decode back to sequences using original tokenizer
|
| 598 |
restored_seqs = self.tokenizer(restored_tokens, mode="decode")
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
# Handle paired sequences format
|
| 602 |
if n_seqs < len(restored_seqs):
|
| 603 |
restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
|
| 604 |
seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
|
| 605 |
|
| 606 |
-
# Apply final formatting
|
| 607 |
from extra_utils import res_to_seq
|
| 608 |
result = np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
|
| 609 |
print(f"DEBUG: Final result: {result}")
|
| 610 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
|
| 612 |
def add_angle_brackets(seq):
|
| 613 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
|
|
| 568 |
n_seqs = len(seqs)
|
| 569 |
|
| 570 |
if align:
|
| 571 |
+
# Implement alignment using ANARCI to create spread sequences
|
| 572 |
+
print("DEBUG: Using alignment to create spread sequences...")
|
| 573 |
+
seqs = self._sequence_aligning(seqs)
|
| 574 |
+
nr_seqs = len(seqs)//self.spread
|
| 575 |
+
|
| 576 |
+
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
|
| 577 |
+
predictions = self.AbLang(tokens)[:,:,1:21]
|
| 578 |
+
|
| 579 |
+
# Reshape
|
| 580 |
+
tokens = tokens.reshape(nr_seqs, self.spread, -1)
|
| 581 |
+
predictions = predictions.reshape(nr_seqs, self.spread, -1, 20)
|
| 582 |
+
seqs = seqs.reshape(nr_seqs, -1)
|
| 583 |
+
|
| 584 |
+
# Find index of best predictions
|
| 585 |
+
best_seq_idx = torch.argmax(torch.max(predictions, -1).values[:,:,1:2].mean(2), -1)
|
| 586 |
+
|
| 587 |
+
# Select best predictions
|
| 588 |
+
tokens = tokens.gather(1, best_seq_idx.view(-1, 1).unsqueeze(1).repeat(1, 1, tokens.shape[-1])).squeeze(1)
|
| 589 |
+
predictions = predictions[range(predictions.shape[0]), best_seq_idx]
|
| 590 |
+
seqs = np.take_along_axis(seqs, best_seq_idx.view(-1, 1).cpu().numpy(), axis=1)
|
| 591 |
+
else:
|
| 592 |
+
print(f"DEBUG: Processing sequences without alignment: {seqs}")
|
| 593 |
+
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
|
| 594 |
+
predictions = self.AbLang(tokens)[:,:,1:21]
|
| 595 |
+
|
| 596 |
predicted_tokens = torch.max(predictions, -1).indices + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
|
| 598 |
+
|
|
|
|
|
|
|
| 599 |
restored_seqs = self.tokenizer(restored_tokens, mode="decode")
|
| 600 |
+
|
|
|
|
|
|
|
| 601 |
if n_seqs < len(restored_seqs):
|
| 602 |
restored_seqs = [f"{h}|{l}".replace('-','') for h,l in zip(restored_seqs[:n_seqs], restored_seqs[n_seqs:])]
|
| 603 |
seqs = [f"{h}|{l}" for h,l in zip(seqs[:n_seqs], seqs[n_seqs:])]
|
| 604 |
|
|
|
|
| 605 |
from extra_utils import res_to_seq
|
| 606 |
result = np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
|
| 607 |
print(f"DEBUG: Final result: {result}")
|
| 608 |
return result
|
| 609 |
+
|
| 610 |
+
def _sequence_aligning(self, seqs):
    """Align paired records with ANARCI and return the spread sequences.

    Each input record looks like '<VH>|<VL>'. The angle-bracket sentinels
    are stripped, the record is split into (heavy, light), and a spread of
    aligned candidates is produced per chain. The result is a single array:
    all heavy-chain spreads first, followed by all light-chain spreads,
    with sentinels re-attached.
    """
    # '<VH>|<VL>' -> ['VH', 'VL'] for every record
    chain_pairs = [
        record.replace(">", "").replace("<", "").split("|")
        for record in seqs
    ]

    per_chain = []
    for chain_id in ('H', 'L'):
        aligned = self._create_spread_of_sequences(chain_pairs, chain=chain_id)
        per_chain.append(np.array([f"<{s}>" for s in aligned]))

    return np.concatenate(per_chain)
|
| 618 |
+
|
| 619 |
+
def _create_spread_of_sequences(self, seqs, chain = 'H'):
    """Number one chain of each record with ANARCI and return spread sequences.

    Parameters
    ----------
    seqs : sequence of (heavy, light) pairs; seq[0] is the heavy chain and
        seq[1] the light chain of each record.
    chain : 'H' selects the heavy chain; anything else selects the light chain.

    Returns
    -------
    A 1-D numpy array holding ``self.spread`` aligned variants per input
    record; records that ANARCI failed to number yield the 'ANARCI_error'
    placeholder instead.
    """
    import pandas as pd
    import anarci

    chain_idx = 0 if chain == 'H' else 1

    # ANARCI takes (name, sequence) pairs. '*' is not a valid residue
    # code, so it is masked with the unknown-residue symbol 'X'.
    # (enumerate replaces the original DataFrame/reset_index round-trip;
    # the ids are the same 0..n-1 integers.)
    numbered_seqs = anarci.run_anarci(
        [(idx, seq[chain_idx].replace('*', 'X')) for idx, seq in enumerate(seqs)],
        ncpu=self.ncpu,
        scheme='imgt',
        allowed_species=['human', 'mouse'],
    )

    # Loop variable renamed from 'anarci' — it shadowed the module above.
    # NOTE(review): '<U90' truncates each numbering string to 90 chars —
    # confirm this is wide enough for the downstream parser.
    anarci_data = pd.DataFrame(
        [str(hit[0][0]) if hit else 'ANARCI_error' for hit in numbered_seqs[1]],
        columns=['anarci']
    ).astype('<U90')

    # Highest IMGT position consumed downstream: 128 for heavy, 127 for light.
    max_position = 128 if chain == 'H' else 127

    from extra_utils import get_sequences_from_anarci
    spread_seqs = anarci_data.apply(
        lambda row: get_sequences_from_anarci(
            row.anarci,
            max_position,
            self.spread,
        ),
        axis=1,
        result_type='expand',
    ).to_numpy().reshape(-1)

    return spread_seqs
|
| 649 |
|
| 650 |
def add_angle_brackets(seq):
|
| 651 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|