hemantn commited on
Commit
b3cd799
·
1 Parent(s): 2e88277

Add detailed debug prints to restore method

Browse files
Files changed (1) hide show
  1. adapter.py +16 -4
adapter.py CHANGED
@@ -572,18 +572,31 @@ class HFAbRestore(AbRestore):
572
  print("WARNING: Alignment not implemented, skipping...")
573
  pass
574
 
 
 
575
  # Tokenize sequences using original interface
576
  tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
 
 
577
 
578
  # Get predictions for amino acids (indices 1-20)
579
  predictions = self.AbLang(tokens)[:,:,1:21]
 
580
 
581
  # Find predicted tokens and replace mask tokens
582
  predicted_tokens = torch.max(predictions, -1).indices + 1
 
 
 
 
 
 
583
  restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
 
584
 
585
  # Decode back to sequences using original tokenizer
586
  restored_seqs = self.tokenizer(restored_tokens, mode="decode")
 
587
 
588
  # Handle paired sequences format
589
  if n_seqs < len(restored_seqs):
@@ -592,7 +605,9 @@ class HFAbRestore(AbRestore):
592
 
593
  # Apply final formatting
594
  from extra_utils import res_to_seq
595
- return np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
 
 
596
 
597
  def add_angle_brackets(seq):
598
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
@@ -939,16 +954,13 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
939
  return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]
940
 
941
  def restore(self, seqs, align=False, **kwargs):
942
- print(f"DEBUG: Input sequences: {seqs}")
943
  hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
944
  restored = hf_abrestore.restore(seqs, align=align)
945
- print(f"DEBUG: Restored before formatting: {restored}")
946
  # Apply angle brackets formatting
947
  if isinstance(restored, np.ndarray):
948
  restored = np.array([add_angle_brackets(seq) for seq in restored])
949
  else:
950
  restored = [add_angle_brackets(seq) for seq in restored]
951
- print(f"DEBUG: Final output: {restored}")
952
  return restored
953
 
954
  def extract_input_ids(tokens, device):
 
572
  print("WARNING: Alignment not implemented, skipping...")
573
  pass
574
 
575
+ print(f"DEBUG: Processing sequences: {seqs}")
576
+
577
  # Tokenize sequences using original interface
578
  tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
579
+ print(f"DEBUG: Tokenized shape: {tokens.shape}")
580
+ print(f"DEBUG: First sequence tokens: {tokens[0]}")
581
 
582
  # Get predictions for amino acids (indices 1-20)
583
  predictions = self.AbLang(tokens)[:,:,1:21]
584
+ print(f"DEBUG: Predictions shape: {predictions.shape}")
585
 
586
  # Find predicted tokens and replace mask tokens
587
  predicted_tokens = torch.max(predictions, -1).indices + 1
588
+ print(f"DEBUG: Predicted tokens: {predicted_tokens[0]}")
589
+
590
+ # Find mask token positions
591
+ mask_positions = (tokens == 23).nonzero(as_tuple=True)
592
+ print(f"DEBUG: Mask token positions: {mask_positions}")
593
+
594
  restored_tokens = torch.where(tokens==23, predicted_tokens, tokens)
595
+ print(f"DEBUG: Restored tokens: {restored_tokens[0]}")
596
 
597
  # Decode back to sequences using original tokenizer
598
  restored_seqs = self.tokenizer(restored_tokens, mode="decode")
599
+ print(f"DEBUG: Decoded sequences: {restored_seqs}")
600
 
601
  # Handle paired sequences format
602
  if n_seqs < len(restored_seqs):
 
605
 
606
  # Apply final formatting
607
  from extra_utils import res_to_seq
608
+ result = np.array([res_to_seq(seq, 'restore') for seq in np.c_[restored_seqs, np.vectorize(len)(seqs)]])
609
+ print(f"DEBUG: Final result: {result}")
610
+ return result
611
 
612
  def add_angle_brackets(seq):
613
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
 
954
  return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]
955
 
956
  def restore(self, seqs, align=False, **kwargs):
 
957
  hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
958
  restored = hf_abrestore.restore(seqs, align=align)
 
959
  # Apply angle brackets formatting
960
  if isinstance(restored, np.ndarray):
961
  restored = np.array([add_angle_brackets(seq) for seq in restored])
962
  else:
963
  restored = [add_angle_brackets(seq) for seq in restored]
 
964
  return restored
965
 
966
  def extract_input_ids(tokens, device):