Remove spaces from decoded sequences in restore method
Browse files — files changed: adapter.py (+3 -1)
adapter.py
CHANGED
|
@@ -588,8 +588,10 @@ class HFAbRestore(AbRestore):
|
|
| 588 |
restored_input_ids = input_ids[0].clone()
|
| 589 |
restored_input_ids[masked_positions] = predicted_tokens
|
| 590 |
|
| 591 |
-
# Decode back to sequence using the original tokenizer
|
| 592 |
restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
|
|
|
|
|
|
|
| 593 |
restored_seqs.append(restored_seq)
|
| 594 |
|
| 595 |
return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
|
|
|
|
| 588 |
restored_input_ids = input_ids[0].clone()
|
| 589 |
restored_input_ids[masked_positions] = predicted_tokens
|
| 590 |
|
| 591 |
+
# Decode back to sequence using the original tokenizer and remove spaces
|
| 592 |
restored_seq = self.tokenizer.tokenizer.decode(restored_input_ids, skip_special_tokens=True)
|
| 593 |
+
# decode() may join the per-token pieces with spaces; strip them so the restored sequence is one contiguous string
|
| 594 |
+
restored_seq = restored_seq.replace(' ', '')
|
| 595 |
restored_seqs.append(restored_seq)
|
| 596 |
|
| 597 |
return np.array(restored_seqs) if len(restored_seqs) > 1 else restored_seqs[0]
|