Add angle brackets back to restore output to match original format
Browse files- adapter.py +5 -0
adapter.py
CHANGED
|
@@ -1023,6 +1023,11 @@ class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScor
|
|
| 1023 |
def restore(self, seqs, align=False, **kwargs):
|
| 1024 |
hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
|
| 1025 |
restored = hf_abrestore.restore(seqs, align=align)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1026 |
return restored
|
| 1027 |
|
| 1028 |
def extract_input_ids(tokens, device):
|
|
|
|
| 1023 |
def restore(self, seqs, align=False, **kwargs):
|
| 1024 |
hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
|
| 1025 |
restored = hf_abrestore.restore(seqs, align=align)
|
| 1026 |
+
# Apply angle brackets formatting to match original format
|
| 1027 |
+
if isinstance(restored, np.ndarray):
|
| 1028 |
+
restored = np.array([add_angle_brackets(seq) for seq in restored])
|
| 1029 |
+
else:
|
| 1030 |
+
restored = [add_angle_brackets(seq) for seq in restored]
|
| 1031 |
return restored
|
| 1032 |
|
| 1033 |
def extract_input_ids(tokens, device):
|