Fix tokenizer and format_seq_input to properly handle paired sequences with angle brackets
Browse files- adapter.py +24 -22
- tokenizer_ablang2paired.py +8 -1
adapter.py
CHANGED
|
@@ -215,31 +215,33 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 215 |
# Local implementation of format_seq_input
|
| 216 |
def format_seq_input(seqs, fragmented=False):
|
| 217 |
"""Format input sequences for processing."""
|
|
|
|
|
|
|
|
|
|
| 218 |
if fragmented:
|
| 219 |
-
# For fragmented sequences,
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
formatted_seqs = []
|
| 224 |
-
for seq in seqs:
|
| 225 |
-
if isinstance(seq, (list, tuple)):
|
| 226 |
-
if len(seq) == 2:
|
| 227 |
-
# Heavy and light chain
|
| 228 |
heavy, light = seq[0], seq[1]
|
| 229 |
-
|
| 230 |
-
formatted_seqs.append(f"{heavy}|{light}")
|
| 231 |
-
elif heavy:
|
| 232 |
-
formatted_seqs.append(heavy)
|
| 233 |
-
elif light:
|
| 234 |
-
formatted_seqs.append(light)
|
| 235 |
-
else:
|
| 236 |
-
formatted_seqs.append("")
|
| 237 |
else:
|
| 238 |
-
formatted_seqs.append(seq)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
valid_modes = [
|
| 245 |
'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
|
|
|
|
# Local implementation of format_seq_input
def format_seq_input(seqs, fragmented=False):
    """Format input sequences for processing.

    Args:
        seqs: A single sequence (a "VH|VL" string or a 2-element
            heavy/light pair) or a list of such items.
        fragmented: If True, emit plain "VH|VL" strings without angle
            brackets; if False, wrap each non-empty chain in angle
            brackets, e.g. "<VH>|<VL>".

    Returns:
        tuple: (list of formatted sequence strings, chain order 'HL').
    """
    # Normalize a single sequence / single pair into a batch of one.
    # Guard against an empty input so seqs[0] cannot raise IndexError;
    # an empty batch simply yields ([], 'HL').
    if seqs and isinstance(seqs[0], str):
        seqs = [seqs]

    formatted_seqs = []
    if fragmented:
        # For fragmented sequences, format as VH|VL without angle brackets
        for seq in seqs:
            if isinstance(seq, (list, tuple)) and len(seq) == 2:
                heavy, light = seq[0], seq[1]
                formatted_seqs.append(f"{heavy}|{light}")
            else:
                # Already a preformatted string — pass through unchanged.
                formatted_seqs.append(seq)
    else:
        # For non-fragmented sequences, add angle brackets: <VH>|<VL>
        for seq in seqs:
            if isinstance(seq, (list, tuple)) and len(seq) == 2:
                heavy, light = seq[0], seq[1]
                # Empty chains get no brackets at all: the placeholder "<>"
                # is stripped afterwards, so a heavy-only pair formats as
                # "<VH>|" and a light-only pair as "|<VL>".
                heavy_part = f"<{heavy}>" if heavy else "<>"
                light_part = f"<{light}>" if light else "<>"
                formatted_seqs.append(f"{heavy_part}|{light_part}".replace("<>", ""))
            else:
                # Already a preformatted string — pass through unchanged.
                formatted_seqs.append(seq)

    return formatted_seqs, 'HL'
|
| 245 |
|
| 246 |
valid_modes = [
|
| 247 |
'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
|
tokenizer_ablang2paired.py
CHANGED
|
@@ -100,9 +100,16 @@ class AbLang2PairedTokenizer(PreTrainedTokenizer):
|
|
| 100 |
return vocab_files
|
| 101 |
|
| 102 |
def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
|
| 103 |
-
#
|
| 104 |
if isinstance(sequences, str):
|
|
|
|
| 105 |
sequences = [sequences]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# Tokenize each sequence
|
| 107 |
input_ids = [[self._convert_token_to_id(tok) for tok in self._tokenize(seq)] for seq in sequences]
|
| 108 |
# Padding
|
|
|
|
| 100 |
return vocab_files
|
| 101 |
|
| 102 |
def __call__(self, sequences, padding=False, return_tensors=None, **kwargs):
|
| 103 |
+
# Handle different input formats
|
| 104 |
if isinstance(sequences, str):
|
| 105 |
+
# Single string: "VH|VL"
|
| 106 |
sequences = [sequences]
|
| 107 |
+
elif isinstance(sequences, list) and len(sequences) > 0:
|
| 108 |
+
if isinstance(sequences[0], list):
|
| 109 |
+
# List of lists: [['VH', 'VL'], ['VH2', 'VL2']]
|
| 110 |
+
sequences = [f"{pair[0]}|{pair[1]}" for pair in sequences]
|
| 111 |
+
# List of strings: ["VH|VL", "VH2|VL2"] - already correct format
|
| 112 |
+
|
| 113 |
# Tokenize each sequence
|
| 114 |
input_ids = [[self._convert_token_to_id(tok) for tok in self._tokenize(seq)] for seq in sequences]
|
| 115 |
# Padding
|