hemantn committed on
Commit
de33042
·
1 Parent(s): 7ca5e33

Add basic alignment support to restore method

Browse files
Files changed (1) hide show
  1. adapter.py +27 -2
adapter.py CHANGED
@@ -107,8 +107,14 @@ class AbAlignment:
107
  return np.concatenate([aligned_encodings], axis=0)
108
 
109
  def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None):
110
- if mode in ['seqcoding', 'restore', 'pseudo_log_likelihood', 'confidence']:
111
  return np.concatenate(subset_list)
 
 
 
 
 
 
112
  elif align:
113
  aligned_subsets = []
114
  for num, subset in enumerate(subset_list):
@@ -594,7 +600,26 @@ class HFAbRestore(AbRestore):
594
  restored_seq = restored_seq.replace(' ', '')
595
  restored_seqs.append(restored_seq)
596
 
597
- return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598
 
599
  def add_angle_brackets(seq):
600
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'
 
107
  return np.concatenate([aligned_encodings], axis=0)
108
 
109
  def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None):
110
+ if mode in ['seqcoding', 'pseudo_log_likelihood', 'confidence']:
111
  return np.concatenate(subset_list)
112
+ elif mode == 'restore' and align:
113
+ # For restore mode with alignment, return the aligned sequences
114
+ return subset_list[0] if len(subset_list) == 1 else subset_list
115
+ elif mode == 'restore' and not align:
116
+ # For restore mode without alignment, return the restored sequences
117
+ return subset_list[0] if len(subset_list) == 1 else subset_list
118
  elif align:
119
  aligned_subsets = []
120
  for num, subset in enumerate(subset_list):
 
600
  restored_seq = restored_seq.replace(' ', '')
601
  restored_seqs.append(restored_seq)
602
 
603
+ # Handle alignment if requested
604
+ if align:
605
+ # Simple alignment: ensure all sequences have the same length by padding
606
+ if len(restored_seqs) > 1:
607
+ # Find the maximum length
608
+ max_len = max(len(seq) for seq in restored_seqs)
609
+ # Pad shorter sequences with the last character
610
+ aligned_seqs = []
611
+ for seq in restored_seqs:
612
+ if len(seq) < max_len:
613
+ # Pad with the last character of the sequence
614
+ padded_seq = seq + seq[-1] * (max_len - len(seq))
615
+ aligned_seqs.append(padded_seq)
616
+ else:
617
+ aligned_seqs.append(seq)
618
+ return np.array(aligned_seqs)
619
+ else:
620
+ return restored_seqs[0]
621
+ else:
622
+ return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
623
 
624
  def add_angle_brackets(seq):
625
  # Assumes input is 'VH|VL' or 'VH|' or '|VL'