Add basic alignment support to restore method
Browse files- adapter.py +27 -2
adapter.py
CHANGED
|
@@ -107,8 +107,14 @@ class AbAlignment:
|
|
| 107 |
return np.concatenate([aligned_encodings], axis=0)
|
| 108 |
|
| 109 |
def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None):
|
| 110 |
-
if mode in ['seqcoding', '
|
| 111 |
return np.concatenate(subset_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
elif align:
|
| 113 |
aligned_subsets = []
|
| 114 |
for num, subset in enumerate(subset_list):
|
|
@@ -594,7 +600,26 @@ class HFAbRestore(AbRestore):
|
|
| 594 |
restored_seq = restored_seq.replace(' ', '')
|
| 595 |
restored_seqs.append(restored_seq)
|
| 596 |
|
| 597 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
|
| 599 |
def add_angle_brackets(seq):
|
| 600 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|
|
|
|
| 107 |
return np.concatenate([aligned_encodings], axis=0)
|
| 108 |
|
| 109 |
def reformat_subsets(self, subset_list, mode='seqcoding', align=False, numbered_seqs=None, seqs=None, number_alignment=None):
|
| 110 |
+
if mode in ['seqcoding', 'pseudo_log_likelihood', 'confidence']:
|
| 111 |
return np.concatenate(subset_list)
|
| 112 |
+
elif mode == 'restore' and align:
|
| 113 |
+
# For restore mode with alignment, return the aligned sequences
|
| 114 |
+
return subset_list[0] if len(subset_list) == 1 else subset_list
|
| 115 |
+
elif mode == 'restore' and not align:
|
| 116 |
+
# For restore mode without alignment, return the restored sequences
|
| 117 |
+
return subset_list[0] if len(subset_list) == 1 else subset_list
|
| 118 |
elif align:
|
| 119 |
aligned_subsets = []
|
| 120 |
for num, subset in enumerate(subset_list):
|
|
|
|
| 600 |
restored_seq = restored_seq.replace(' ', '')
|
| 601 |
restored_seqs.append(restored_seq)
|
| 602 |
|
| 603 |
+
# Handle alignment if requested
|
| 604 |
+
if align:
|
| 605 |
+
# Simple alignment: ensure all sequences have the same length by padding
|
| 606 |
+
if len(restored_seqs) > 1:
|
| 607 |
+
# Find the maximum length
|
| 608 |
+
max_len = max(len(seq) for seq in restored_seqs)
|
| 609 |
+
# Pad shorter sequences with the last character
|
| 610 |
+
aligned_seqs = []
|
| 611 |
+
for seq in restored_seqs:
|
| 612 |
+
if len(seq) < max_len:
|
| 613 |
+
# Pad with the last character of the sequence
|
| 614 |
+
padded_seq = seq + seq[-1] * (max_len - len(seq))
|
| 615 |
+
aligned_seqs.append(padded_seq)
|
| 616 |
+
else:
|
| 617 |
+
aligned_seqs.append(seq)
|
| 618 |
+
return np.array(aligned_seqs)
|
| 619 |
+
else:
|
| 620 |
+
return restored_seqs[0]
|
| 621 |
+
else:
|
| 622 |
+
return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
|
| 623 |
|
| 624 |
def add_angle_brackets(seq):
|
| 625 |
# Assumes input is 'VH|VL' or 'VH|' or '|VL'
|