Update tokenizer_script.py
Browse files- tokenizer_script.py +5 -5
tokenizer_script.py
CHANGED
|
@@ -76,19 +76,19 @@ class CharacterTokenizer(PreTrainedTokenizer):
|
|
| 76 |
|
| 77 |
return (vocab_file,)
|
| 78 |
|
| 79 |
-
def batch_encode(
|
| 80 |
-
encoded_texts = [
|
| 81 |
# Handle max_length (truncation)
|
| 82 |
if max_length is not None:
|
| 83 |
encoded_texts = [ids[:max_length] for ids in encoded_texts]
|
| 84 |
if add_special_tokens:
|
| 85 |
-
bos_token_id =
|
| 86 |
-
eos_token_id =
|
| 87 |
encoded_texts = [[bos_token_id] + ids + [eos_token_id] for ids in encoded_texts]
|
| 88 |
# Handle padding
|
| 89 |
if padding:
|
| 90 |
# properly handle padding side
|
| 91 |
-
pad_id =
|
| 92 |
max_len = max(len(ids) for ids in encoded_texts) if max_length is None else max_length
|
| 93 |
if tokenizer.padding_side == "right":
|
| 94 |
encoded_texts = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded_texts]
|
|
|
|
| 76 |
|
| 77 |
return (vocab_file,)
|
| 78 |
|
| 79 |
+
def batch_encode(self, texts, add_special_tokens=False, padding=False, truncation=True, max_length=None):
|
| 80 |
+
encoded_texts = [self.encode(text) for text in texts]
|
| 81 |
# Handle max_length (truncation)
|
| 82 |
if max_length is not None:
|
| 83 |
encoded_texts = [ids[:max_length] for ids in encoded_texts]
|
| 84 |
if add_special_tokens:
|
| 85 |
+
bos_token_id = self.convert_tokens_to_ids(tokenizer.bos_token)
|
| 86 |
+
eos_token_id = self.convert_tokens_to_ids(tokenizer.eos_token)
|
| 87 |
encoded_texts = [[bos_token_id] + ids + [eos_token_id] for ids in encoded_texts]
|
| 88 |
# Handle padding
|
| 89 |
if padding:
|
| 90 |
# properly handle padding side
|
| 91 |
+
pad_id = self.vocab.get(tokenizer.pad_token, 0)
|
| 92 |
max_len = max(len(ids) for ids in encoded_texts) if max_length is None else max_length
|
| 93 |
if tokenizer.padding_side == "right":
|
| 94 |
encoded_texts = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded_texts]
|