AnthonyDi committed on
Commit
c402548
·
verified ·
1 Parent(s): 03bd84e

Upload tokenizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tokenizer.py +59 -3
tokenizer.py CHANGED
@@ -4,8 +4,8 @@ from transformers import PreTrainedTokenizer
4
 
5
  class CharacterTokenizer(PreTrainedTokenizer):
6
  """
7
- Character-level tokenizer for OCR tasks.
8
- Each character becomes a separate token.
9
  """
10
 
11
  def __init__(
@@ -75,7 +75,7 @@ class CharacterTokenizer(PreTrainedTokenizer):
75
 
76
  # Remove vocab_file from kwargs if it exists to avoid duplicate argument
77
  kwargs.pop('vocab_file', None)
78
-
79
  return cls(vocab_file=vocab_file, **kwargs)
80
 
81
  @property
@@ -86,15 +86,68 @@ class CharacterTokenizer(PreTrainedTokenizer):
86
  return self.token_to_id
87
 
88
  def _tokenize(self, text):
 
89
  return list(text)
90
 
91
  def _convert_token_to_id(self, token):
 
92
  return self.token_to_id.get(token, self.unk_token_id)
93
 
94
  def _convert_id_to_token(self, index):
 
95
  return self.id_to_token.get(index, self.unk_token)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def save_vocabulary(self, save_directory, filename_prefix=None):
 
98
  os.makedirs(save_directory, exist_ok=True)
99
 
100
  vocab_path = os.path.join(save_directory, "vocab.json")
@@ -113,11 +166,13 @@ class CharacterTokenizer(PreTrainedTokenizer):
113
  "unk_token": self.unk_token,
114
  "pad_token": self.pad_token,
115
  "vocab_file": "vocab.json",
 
116
  }, f, indent=2)
117
 
118
  return (vocab_path,)
119
 
120
  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
 
121
  if token_ids_1 is None:
122
  return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
123
  else:
@@ -130,6 +185,7 @@ class CharacterTokenizer(PreTrainedTokenizer):
130
  )
131
 
132
  def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
 
133
  return [0] * len(
134
  self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
135
  )
 
4
 
5
  class CharacterTokenizer(PreTrainedTokenizer):
6
  """
7
+ Character-level tokenizer for OCR tasks that follows HuggingFace conventions.
8
+ Each character becomes a separate token, but decoding produces continuous text.
9
  """
10
 
11
  def __init__(
 
75
 
76
  # Remove vocab_file from kwargs if it exists to avoid duplicate argument
77
  kwargs.pop('vocab_file', None)
78
+
79
  return cls(vocab_file=vocab_file, **kwargs)
80
 
81
  @property
 
86
  return self.token_to_id
87
 
88
  def _tokenize(self, text):
89
+ """Tokenize text into individual characters"""
90
  return list(text)
91
 
92
  def _convert_token_to_id(self, token):
93
+ """Convert a token (character) to its ID"""
94
  return self.token_to_id.get(token, self.unk_token_id)
95
 
96
  def _convert_id_to_token(self, index):
97
+ """Convert an ID to its token (character)"""
98
  return self.id_to_token.get(index, self.unk_token)
99
 
100
+ def convert_tokens_to_string(self, tokens):
101
+ """
102
+ Convert a sequence of tokens to a single string.
103
+ This is the KEY method that HuggingFace uses for decoding!
104
+ For character-level tokenization, we join without spaces.
105
+ """
106
+ # Filter out special tokens
107
+ filtered_tokens = []
108
+ for token in tokens:
109
+ if token not in {self.pad_token, self.bos_token, self.eos_token, self.unk_token}:
110
+ filtered_tokens.append(token)
111
+
112
+ # Join characters directly without spaces
113
+ return ''.join(filtered_tokens)
114
+
115
+ def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, **kwargs):
116
+ """
117
+ Override decode to ensure proper character-level decoding.
118
+ This follows HuggingFace conventions but handles character-level properly.
119
+ """
120
+ # Convert tensor to list if needed
121
+ if hasattr(token_ids, 'tolist'):
122
+ token_ids = token_ids.tolist()
123
+
124
+ # Convert IDs to tokens
125
+ tokens = [self._convert_id_to_token(id) for id in token_ids]
126
+
127
+ # Filter special tokens if requested
128
+ if skip_special_tokens:
129
+ tokens = [token for token in tokens if token not in {
130
+ self.pad_token, self.bos_token, self.eos_token, self.unk_token
131
+ }]
132
+
133
+ # Use our convert_tokens_to_string method
134
+ text = self.convert_tokens_to_string(tokens)
135
+
136
+ # For character-level, we don't want clean_up_tokenization_spaces
137
+ # since we're not using word-level tokenization
138
+ return text
139
+
140
+ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=True, **kwargs):
141
+ """
142
+ Batch decode following HuggingFace conventions
143
+ """
144
+ return [
145
+ self.decode(seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
146
+ for seq in sequences
147
+ ]
148
+
149
  def save_vocabulary(self, save_directory, filename_prefix=None):
150
+ """Save vocabulary following HuggingFace conventions"""
151
  os.makedirs(save_directory, exist_ok=True)
152
 
153
  vocab_path = os.path.join(save_directory, "vocab.json")
 
166
  "unk_token": self.unk_token,
167
  "pad_token": self.pad_token,
168
  "vocab_file": "vocab.json",
169
+ "clean_up_tokenization_spaces": False, # Important for character-level
170
  }, f, indent=2)
171
 
172
  return (vocab_path,)
173
 
174
  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
175
+ """Build inputs with special tokens following HuggingFace conventions"""
176
  if token_ids_1 is None:
177
  return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
178
  else:
 
185
  )
186
 
187
  def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
188
+ """Create token type IDs following HuggingFace conventions"""
189
  return [0] * len(
190
  self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
191
  )