Add debug prints to restore method
Browse files- adapter.py +10 -0
adapter.py
CHANGED
|
@@ -572,15 +572,23 @@ class HFAbRestore(AbRestore):
|
|
| 572 |
restored_seqs.append(seq)
|
| 573 |
continue
|
| 574 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
# Tokenize the sequence
|
| 576 |
input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)
|
|
|
|
|
|
|
| 577 |
|
| 578 |
# Find masked tokens (assuming * is the mask token)
|
| 579 |
mask_token_id = self.tokenizer.mask_token_id
|
| 580 |
masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
|
|
|
|
| 581 |
|
| 582 |
if len(masked_positions) == 0:
|
| 583 |
# No masked tokens found, return original
|
|
|
|
| 584 |
restored_seqs.append(seq)
|
| 585 |
continue
|
| 586 |
|
|
@@ -595,6 +603,7 @@ class HFAbRestore(AbRestore):
|
|
| 595 |
# Get predictions for masked positions
|
| 596 |
masked_logits = logits[0, masked_positions]
|
| 597 |
predicted_tokens = torch.argmax(masked_logits, dim=-1)
|
|
|
|
| 598 |
|
| 599 |
# Replace masked tokens with predicted tokens
|
| 600 |
restored_input_ids = input_ids[0].clone()
|
|
@@ -603,6 +612,7 @@ class HFAbRestore(AbRestore):
|
|
| 603 |
# Decode back to sequence
|
| 604 |
restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
|
| 605 |
restored_seq = restored_seq.replace(' ', '')
|
|
|
|
| 606 |
restored_seqs.append(restored_seq)
|
| 607 |
|
| 608 |
return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
|
|
|
|
| 572 |
restored_seqs.append(seq)
|
| 573 |
continue
|
| 574 |
|
| 575 |
+
# Debug: Print the sequence and mask token info
|
| 576 |
+
print(f"Processing sequence: {seq}")
|
| 577 |
+
print(f"Mask token ID: {self.tokenizer.mask_token_id}")
|
| 578 |
+
|
| 579 |
# Tokenize the sequence
|
| 580 |
input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)
|
| 581 |
+
print(f"Input IDs shape: {input_ids.shape}")
|
| 582 |
+
print(f"Input IDs: {input_ids[0]}")
|
| 583 |
|
| 584 |
# Find masked tokens (assuming * is the mask token)
|
| 585 |
mask_token_id = self.tokenizer.mask_token_id
|
| 586 |
masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
|
| 587 |
+
print(f"Masked positions: {masked_positions}")
|
| 588 |
|
| 589 |
if len(masked_positions) == 0:
|
| 590 |
# No masked tokens found, return original
|
| 591 |
+
print("No masked tokens found, returning original")
|
| 592 |
restored_seqs.append(seq)
|
| 593 |
continue
|
| 594 |
|
|
|
|
| 603 |
# Get predictions for masked positions
|
| 604 |
masked_logits = logits[0, masked_positions]
|
| 605 |
predicted_tokens = torch.argmax(masked_logits, dim=-1)
|
| 606 |
+
print(f"Predicted tokens: {predicted_tokens}")
|
| 607 |
|
| 608 |
# Replace masked tokens with predicted tokens
|
| 609 |
restored_input_ids = input_ids[0].clone()
|
|
|
|
| 612 |
# Decode back to sequence
|
| 613 |
restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
|
| 614 |
restored_seq = restored_seq.replace(' ', '')
|
| 615 |
+
print(f"Restored sequence: {restored_seq}")
|
| 616 |
restored_seqs.append(restored_seq)
|
| 617 |
|
| 618 |
return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
|