hemantn committed on
Commit
f9eb722
·
1 Parent(s): acf4b07

Add debug prints to restore method

Browse files
Files changed (1) hide show
  1. adapter.py +10 -0
adapter.py CHANGED
@@ -572,15 +572,23 @@ class HFAbRestore(AbRestore):
572
  restored_seqs.append(seq)
573
  continue
574
 
 
 
 
 
575
  # Tokenize the sequence
576
  input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)
 
 
577
 
578
  # Find masked tokens (assuming * is the mask token)
579
  mask_token_id = self.tokenizer.mask_token_id
580
  masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
 
581
 
582
  if len(masked_positions) == 0:
583
  # No masked tokens found, return original
 
584
  restored_seqs.append(seq)
585
  continue
586
 
@@ -595,6 +603,7 @@ class HFAbRestore(AbRestore):
595
  # Get predictions for masked positions
596
  masked_logits = logits[0, masked_positions]
597
  predicted_tokens = torch.argmax(masked_logits, dim=-1)
 
598
 
599
  # Replace masked tokens with predicted tokens
600
  restored_input_ids = input_ids[0].clone()
@@ -603,6 +612,7 @@ class HFAbRestore(AbRestore):
603
  # Decode back to sequence
604
  restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
605
  restored_seq = restored_seq.replace(' ', '')
 
606
  restored_seqs.append(restored_seq)
607
 
608
  return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
 
572
  restored_seqs.append(seq)
573
  continue
574
 
575
+ # Debug: Print the sequence and mask token info
576
+ print(f"Processing sequence: {seq}")
577
+ print(f"Mask token ID: {self.tokenizer.mask_token_id}")
578
+
579
  # Tokenize the sequence
580
  input_ids = self.tokenizer([seq], pad=True, w_extra_tkns=False, device=self.used_device)
581
+ print(f"Input IDs shape: {input_ids.shape}")
582
+ print(f"Input IDs: {input_ids[0]}")
583
 
584
  # Find masked tokens (assuming * is the mask token)
585
  mask_token_id = self.tokenizer.mask_token_id
586
  masked_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
587
+ print(f"Masked positions: {masked_positions}")
588
 
589
  if len(masked_positions) == 0:
590
  # No masked tokens found, return original
591
+ print("No masked tokens found, returning original")
592
  restored_seqs.append(seq)
593
  continue
594
 
 
603
  # Get predictions for masked positions
604
  masked_logits = logits[0, masked_positions]
605
  predicted_tokens = torch.argmax(masked_logits, dim=-1)
606
+ print(f"Predicted tokens: {predicted_tokens}")
607
 
608
  # Replace masked tokens with predicted tokens
609
  restored_input_ids = input_ids[0].clone()
 
612
  # Decode back to sequence
613
  restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
614
  restored_seq = restored_seq.replace(' ', '')
615
+ print(f"Restored sequence: {restored_seq}")
616
  restored_seqs.append(restored_seq)
617
 
618
  return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]