| import json |
| import os |
| from transformers import PreTrainedTokenizer |
|
|
class CharacterTokenizer(PreTrainedTokenizer):
    """
    Character-level tokenizer for OCR tasks.

    Each character of the input text becomes one token; ids come from a
    JSON vocabulary file mapping token -> integer id.
    """

    def __init__(
        self,
        vocab_file=None,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<s>",
        eos_token="</s>",
        max_length=256,
        **kwargs
    ):
        """
        Args:
            vocab_file: Path to a JSON file mapping token -> integer id.
            unk_token, pad_token, bos_token, eos_token: Special tokens.
            max_length: Default maximum sequence length; forwarded to the
                base class as ``model_max_length`` unless the caller already
                supplied one in ``kwargs``.
            **kwargs: Passed through to ``PreTrainedTokenizer``.

        Raises:
            ValueError: If ``vocab_file`` is missing or does not exist.
        """
        if vocab_file is None or not os.path.isfile(vocab_file):
            raise ValueError("`vocab_file` must be provided or exist.")

        # The vocab must be loaded BEFORE super().__init__(): the base
        # class may call get_vocab()/vocab_size while registering the
        # special tokens.
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.token_to_id = json.load(f)
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

        self.max_length = max_length
        # Wire max_length into HF's truncation machinery. setdefault keeps
        # this backward compatible: an explicit model_max_length wins.
        kwargs.setdefault("model_max_length", max_length)

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs
        )

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """Register this tokenizer class for use with AutoTokenizer.

        The previous override returned ``cls`` without recording anything,
        so ``_auto_class`` was never set and the normal ``save_pretrained``
        auto_map mechanism never saw this class. Record it explicitly; the
        ``return cls`` is kept for existing callers.
        """
        cls._auto_class = auto_class
        return cls

    @property
    def vocab_size(self):
        """Size of the base vocabulary (excludes tokens added later)."""
        return len(self.token_to_id)

    def get_vocab(self):
        """Return a copy of the token -> id mapping (safe for callers to mutate)."""
        return dict(self.token_to_id)

    def _tokenize(self, text):
        """Split ``text`` into single-character tokens."""
        return list(text)

    def _convert_token_to_id(self, token):
        """Map a token to its id, falling back to the unk token id."""
        return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index):
        """Map an id back to its token, falling back to the unk token."""
        return self.id_to_token.get(index, self.unk_token)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary and a tokenizer config to ``save_directory``.

        Args:
            save_directory: Target directory; created if absent.
            filename_prefix: Optional filename prefix, per the
                ``PreTrainedTokenizer.save_vocabulary`` contract
                (previously ignored).

        Returns:
            A 1-tuple with the path of the written vocab file.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Honor filename_prefix as the base-class contract requires.
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_path = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)

        # NOTE(review): save_pretrained() also writes tokenizer_config.json;
        # this manual copy exists to pin the auto_map entry needed for
        # trust_remote_code loading.
        config_path = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump({
                "tokenizer_class": "CharacterTokenizer",
                "auto_map": {
                    "AutoTokenizer": ["tokenizer.CharacterTokenizer", None]
                },
                "bos_token": self.bos_token,
                "eos_token": self.eos_token,
                "unk_token": self.unk_token,
                "pad_token": self.pad_token,
                # Record the actual filename (was hard-coded "vocab.json",
                # wrong whenever a prefix is applied).
                "vocab_file": os.path.basename(vocab_path),
            }, f, indent=2)

        return (vocab_path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap one sequence as ``<s> A </s>``, a pair as ``<s> A </s> B </s>``."""
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """All-zero segment ids: this tokenizer uses a single segment type."""
        return [0] * len(
            self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        )
|
|