Add detailed debug prints to restore method
Browse files- adapter.py +16 -4
adapter.py
CHANGED
|
@@ -572,18 +572,31 @@ class HFAbRestore(AbRestore):
|
|
| 572 |
print("WARNING: Alignment not implemented, skipping...")
|
| 573 |
pass
|
| 574 |
|
|
|
|
|
|
|
| 575 |
# Tokenize sequences using original interface
|
| 576 |
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
|
|
|
|
|
|
|
| 577 |
|
| 578 |
# Get predictions for amino acids (indices 1-20)
|
| 579 |
predictions = self.AbLang(tokens)[:,:,1:21]
|
|
|
|
| 580 |
|
| 581 |
# Find predicted tokens and replace mask tokens
|
| 582 |
predicted_tokens = torch.max(predictions, -1).indices + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
|
|
|
|
| 584 |
|
| 585 |
# Decode back to sequences using original tokenizer
|
| 586 |
restored_seqs = self.tokenizer(restored_tokens, mode="decode")
|
|
|
|
| 587 |
|
| 588 |
# Handle paired sequences format
|
| 589 |
if n_seqs < len(restored_seqs):
|
|
@@ -592,7 +605,9 @@ class HFAbRestore(AbRestore):
|
|
| 592 |
|
| 593 |
# Apply final formatting
|
| 594 |
from extra_utils import res_to_seq
|
| 595 |
-
|
|
|
|
|
|
|
| 596 |
|
| 597 |
def add_angle_brackets(seq):
|
| 598 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
@@ -939,16 +954,13 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 939 |
return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]
|
| 940 |
|
| 941 |
def restore(self, seqs, align=False, **kwargs):
|
| 942 |
-
print(f"DEBUG: Input sequences: {seqs}")
|
| 943 |
hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
|
| 944 |
restored = hf_abrestore.restore(seqs, align=align)
|
| 945 |
-
print(f"DEBUG: Restored before formatting: {restored}")
|
| 946 |
# Apply angle brackets formatting
|
| 947 |
if isinstance(restored, np.ndarray):
|
| 948 |
restored = np.array([add_angle_brackets(seq) for seq in restored])
|
| 949 |
else:
|
| 950 |
restored = [add_angle_brackets(seq) for seq in restored]
|
| 951 |
-
print(f"DEBUG: Final output: {restored}")
|
| 952 |
return restored
|
| 953 |
|
| 954 |
def extract_input_ids(tokens, device):
|
|
|
|
| 572 |
print("WARNING: Alignment not implemented, skipping...")
|
| 573 |
pass
|
| 574 |
|
| 575 |
+
print(f"DEBUG: Processing sequences: {seqs}")
|
| 576 |
+
|
| 577 |
# Tokenize sequences using original interface
|
| 578 |
tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
|
| 579 |
+
print(f"DEBUG: Tokenized shape: {tokens.shape}")
|
| 580 |
+
print(f"DEBUG: First sequence tokens: {tokens[0]}")
|
| 581 |
|
| 582 |
# Get predictions for amino acids (indices 1-20)
|
| 583 |
predictions = self.AbLang(tokens)[:,:,1:21]
|
| 584 |
+
print(f"DEBUG: Predictions shape: {predictions.shape}")
|
| 585 |
|
| 586 |
# Find predicted tokens and replace mask tokens
|
| 587 |
predicted_tokens = torch.max(predictions, -1).indices + 1
|
| 588 |
+
print(f"DEBUG: Predicted tokens: {predicted_tokens[0]}")
|
| 589 |
+
|
| 590 |
+
# Find mask token positions
|
| 591 |
+
mask_positions = (tokens == 23).nonzero(as_tuple=True)
|
| 592 |
+
print(f"DEBUG: Mask token positions: {mask_positions}")
|
| 593 |
+
|
| 594 |
restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
|
| 595 |
+
print(f"DEBUG: Restored tokens: {restored_tokens[0]}")
|
| 596 |
|
| 597 |
# Decode back to sequences using original tokenizer
|
| 598 |
restored_seqs = self.tokenizer(restored_tokens, mode="decode")
|
| 599 |
+
print(f"DEBUG: Decoded sequences: {restored_seqs}")
|
| 600 |
|
| 601 |
# Handle paired sequences format
|
| 602 |
if n_seqs < len(restored_seqs):
|
|
|
|
| 605 |
|
| 606 |
# Apply final formatting
|
| 607 |
from extra_utils import res_to_seq
|
| 608 |
+
result = np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
|
| 609 |
+
print(f"DEBUG: Final result: {result}")
|
| 610 |
+
return result
|
| 611 |
|
| 612 |
def add_angle_brackets(seq):
|
| 613 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
|
|
| 954 |
return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]
|
| 955 |
|
| 956 |
def restore(self, seqs, align=False, **kwargs):
|
|
|
|
| 957 |
hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
|
| 958 |
restored = hf_abrestore.restore(seqs, align=align)
|
|
|
|
| 959 |
# Apply angle brackets formatting
|
| 960 |
if isinstance(restored, np.ndarray):
|
| 961 |
restored = np.array([add_angle_brackets(seq) for seq in restored])
|
| 962 |
else:
|
| 963 |
restored = [add_angle_brackets(seq) for seq in restored]
|
|
|
|
| 964 |
return restored
|
| 965 |
|
| 966 |
def extract_input_ids(tokens, device):
|