WildnerveAI committed on
Commit
53d82e6
·
verified ·
1 Parent(s): c8e5b57

Upload 2 files

Browse files
Files changed (1) hide show
  1. tokenizer.py +223 -6
tokenizer.py CHANGED
@@ -1,14 +1,231 @@
 
 
 
1
  import os
2
- import json
3
  import logging
4
- from typing import List, Dict, Optional, Union, Any
 
5
 
6
  logger = logging.getLogger(__name__)
7
 
 
 
 
8
  class TokenizerWrapper:
9
- """Lightweight wrapper around GPT-2 tokenizer with memory optimization"""
10
 
11
- def __init__(self, model_name: str = "gpt2", load_vocab: bool = True):
12
  self.model_name = model_name
13
- self.pad_token = "<pad>"
14
- self.eos_token = "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tokenizer wrapper to handle different tokenizer types with a consistent interface
3
+ """
4
  import os
5
+ import sys
6
  import logging
7
+ from typing import Dict, List, Optional, Union, Any
8
+ import torch
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
+ # Set memory optimization flag
13
+ os.environ["LOW_MEMORY_MODE"] = "1"
14
+
class TokenizerWrapper:
    """Uniform facade over a Hugging Face tokenizer with graceful fallbacks.

    The wrapper tries to load ``model_name`` through
    ``transformers.AutoTokenizer``, falls back to the stock ``gpt2``
    tokenizer, and finally to the in-file ``MinimalTokenizer`` placeholder.
    Attributes the wrapper does not define itself are delegated to the
    wrapped tokenizer.

    Attributes:
        model_name: Name passed to ``AutoTokenizer.from_pretrained``.
        use_fast: Forwarded as ``use_fast`` on the first load attempt.
        tokenizer: The wrapped tokenizer object.
        eos_token, pad_token, unk_token, mask_token, bos_token: Special
            token strings. They mirror the wrapped tokenizer where it
            defines them and keep placeholder defaults otherwise.
    """

    # Special-token attributes kept in sync with the wrapped tokenizer.
    _SPECIAL_TOKEN_ATTRS = (
        "eos_token",
        "pad_token",
        "unk_token",
        "mask_token",
        "bos_token",
    )

    def __init__(self, model_name="gpt2", use_fast=True, *args, **kwargs):
        """Load the tokenizer and settle the special tokens.

        Args:
            model_name: Model identifier to load the tokenizer for.
            use_fast: Whether to request the fast tokenizer implementation.
            *args: Accepted for call compatibility; not used.
            **kwargs: Accepted for call compatibility; not used.
        """
        self.model_name = model_name
        self.use_fast = use_fast
        self.tokenizer = None
        self._initialize_tokenizer()

        # Placeholder defaults; they only survive for tokens the wrapped
        # tokenizer does not define (see _sync_special_tokens below).
        self.eos_token = "</s>"
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.mask_token = "[MASK]"
        self.bos_token = "<s>"

        # Ensure pad_token is always set (critical for GPT-2).
        self._ensure_pad_token()
        # Instance attributes shadow __getattr__ delegation, so copy the real
        # tokens over; otherwise e.g. eos_token would disagree with
        # eos_token_id, which already comes from the wrapped tokenizer.
        self._sync_special_tokens()

        logger.info(f"Initialized TokenizerWrapper with {model_name}")

    def _initialize_tokenizer(self):
        """Populate ``self.tokenizer``, degrading step by step on failure.

        Order of attempts: the requested model, the ``gpt2`` tokenizer, then
        ``MinimalTokenizer``. Failures are logged rather than raised, so
        construction itself is best-effort by design.
        """
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=self.use_fast,
            )
            logger.info(f"Successfully loaded {self.model_name} tokenizer")
            return
        except Exception as e:
            logger.warning(f"Error loading {self.model_name} tokenizer: {e}")

        try:
            # Fallback to GPT-2
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            logger.info("Loaded fallback GPT-2 tokenizer")
        except Exception as e2:
            logger.error(f"Failed to load fallback tokenizer: {e2}")
            # Create minimal placeholder
            self.tokenizer = MinimalTokenizer()
            logger.warning("Using minimal placeholder tokenizer")

    def _ensure_pad_token(self):
        """Give the wrapped tokenizer a pad token if it has none.

        The EOS token is reused when available; otherwise the literal
        ``"[PAD]"`` is assigned. Does nothing when no tokenizer is loaded or
        a pad token already exists.
        """
        backend = self.tokenizer
        # Compare against None explicitly: tokenizers may define __len__, so
        # plain truthiness would not mean "missing".
        if backend is None:
            return
        if getattr(backend, "pad_token", None) is not None:
            return

        eos = getattr(backend, "eos_token", None)
        if eos:
            # GPT-2 doesn't have a pad_token by default, use eos_token instead
            backend.pad_token = eos
            self.pad_token = backend.pad_token
            logger.info(f"Set pad_token to eos_token: {self.pad_token}")
        else:
            # Last resort
            backend.pad_token = "[PAD]"
            self.pad_token = "[PAD]"
            logger.info("Set default pad_token: [PAD]")

    def _sync_special_tokens(self):
        """Copy the wrapped tokenizer's special tokens onto the wrapper.

        Only tokens the wrapped tokenizer actually defines are copied; the
        wrapper's current value is kept for the rest.
        """
        backend = self.__dict__.get("tokenizer")
        if backend is None:
            return

        # Prefer the mapping of *set* special tokens when the tokenizer
        # offers one; otherwise read the attributes directly.
        token_map = getattr(backend, "special_tokens_map", None)
        for attr in self._SPECIAL_TOKEN_ATTRS:
            if isinstance(token_map, dict):
                value = token_map.get(attr)
            else:
                value = getattr(backend, attr, None)
            if value is not None:
                setattr(self, attr, value)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the wrapped tokenizer (50257 if unknown)."""
        backend = self.tokenizer
        if hasattr(backend, "vocab_size"):
            return backend.vocab_size
        if hasattr(backend, "get_vocab"):
            return len(backend.get_vocab())
        return 50257  # Default GPT-2 vocab size

    @property
    def pad_token_id(self) -> int:
        """Pad token id, falling back to the EOS id and finally to 0."""
        for attr in ("pad_token_id", "eos_token_id"):
            token_id = getattr(self.tokenizer, attr, None)
            if token_id is not None:
                return token_id
        return 0  # Last resort fallback

    @property
    def eos_token_id(self) -> int:
        """EOS token id, falling back to 50256 (the GPT-2 value)."""
        token_id = getattr(self.tokenizer, "eos_token_id", None)
        if token_id is not None:
            return token_id
        return 50256  # Default for GPT-2

    def __call__(self, text, *args, **kwargs):
        """Tokenize ``text`` by delegating to the wrapped tokenizer.

        Args:
            text: A string or a list of strings.
            *args: Forwarded to the wrapped tokenizer.
            **kwargs: Forwarded to the wrapped tokenizer.

        Returns:
            Whatever the wrapped tokenizer returns. If no tokenizer is
            loaded, a dict with ``input_ids`` and ``attention_mask`` tensors
            of ones shaped ``(batch, 10)``, where ``batch`` is ``len(text)``
            for a list and 1 otherwise.
        """
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            # Imported lazily: only this degraded path needs torch.
            import torch

            batch_size = len(text) if isinstance(text, list) else 1
            dummy_ids = torch.ones((batch_size, 10), dtype=torch.long)
            return {
                "input_ids": dummy_ids,
                "attention_mask": torch.ones_like(dummy_ids),
            }

        return self.tokenizer(text, *args, **kwargs)

    def encode(self, text, *args, **kwargs):
        """Encode text to token ids via the wrapped tokenizer.

        Without a tokenizer, returns ten dummy ids for a string, or one such
        list per element for any other iterable input.
        """
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            if isinstance(text, str):
                return [1] * 10  # Return minimal dummy encoding
            return [[1] * 10 for _ in text]  # Batch of dummy encodings

        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids, *args, **kwargs):
        """Decode token ids to text via the wrapped tokenizer.

        Without a tokenizer, returns a fixed error string instead of raising.
        """
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return "Error: Tokenizer not initialized"

        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_decode(self, sequences, *args, **kwargs):
        """Decode several sequences via the wrapped tokenizer.

        Without a tokenizer, returns one fixed error string per sequence.
        """
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return ["Error: Tokenizer not initialized"] * len(sequences)

        return self.tokenizer.batch_decode(sequences, *args, **kwargs)

    def __getattr__(self, name):
        """Delegate attributes the wrapper lacks to the wrapped tokenizer.

        Raises:
            AttributeError: If neither the wrapper nor the wrapped tokenizer
                has ``name``.
        """
        # Read the instance dict directly: going through ``self.tokenizer``
        # would re-enter __getattr__ and recurse forever whenever the
        # attribute is not set yet (copy/pickle, or __new__ without __init__).
        backend = self.__dict__.get("tokenizer")
        if backend is not None and hasattr(backend, name):
            return getattr(backend, name)
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
147
+
148
+
149
class MinimalTokenizer:
    """Placeholder tokenizer used when no real tokenizer can be loaded.

    It splits text on whitespace and assigns ids by token *position*, not
    by token content, so it preserves sequence lengths and the overall
    tokenizer call shape but carries no linguistic information. Ids 0-3 are
    reserved for the pad/eos/bos/unk tokens.
    """

    def __init__(self):
        """Set the fixed special tokens and the nominal vocabulary size."""
        self.pad_token = "[PAD]"
        self.pad_token_id = 0
        self.eos_token = "</s>"
        self.eos_token_id = 1
        self.bos_token = "<s>"
        self.bos_token_id = 2
        self.unk_token = "[UNK]"
        self.unk_token_id = 3
        self.vocab_size = 50257  # Standard GPT-2 vocab size
        logger.warning("Using minimal placeholder tokenizer with no actual encoding/decoding capability")

    def _placeholder_ids(self, text, max_length=None):
        """Return position-based ids for the whitespace tokens of ``text``.

        Ids start at 4 (past the reserved special ids) and wrap around
        within the vocabulary. A truthy ``max_length`` truncates the tokens.
        """
        words = text.split()
        if max_length:
            words = words[:max_length]
        usable = self.vocab_size - 4
        return [position % usable + 4 for position in range(len(words))]

    def __call__(self, text, return_tensors=None, padding=False, truncation=False, max_length=None, *args, **kwargs):
        """Tokenize a string or an iterable of strings.

        Args:
            text: A single string, or an iterable of strings (a batch).
            return_tensors: ``"pt"`` returns 2-D ``torch.long`` tensors; any
                other value returns plain lists.
            padding: If truthy, a single string is padded up to
                ``max_length`` (when given) and a batch is padded to its
                longest sequence.
            truncation: Accepted for API compatibility; truncation happens
                whenever ``max_length`` is given.
            max_length: Maximum number of tokens kept per sequence.

        Returns:
            A dict with ``input_ids`` and ``attention_mask``. The mask has
            the same shape as ``input_ids`` and is 0 on padded positions.
        """
        is_batch = not isinstance(text, str)
        texts = list(text) if is_batch else [text]

        rows = [self._placeholder_ids(item, max_length) for item in texts]
        masks = [[1] * len(row) for row in rows]

        if padding:
            if is_batch:
                target = max((len(row) for row in rows), default=0)
            else:
                # A lone string is only padded when max_length is given.
                target = max_length or 0
            for row, mask in zip(rows, masks):
                missing = max(0, target - len(row))
                row.extend([self.pad_token_id] * missing)
                # Padding must not be attended to.
                mask.extend([0] * missing)

        # Convert to tensor if requested
        if return_tensors == "pt":
            import torch
            # rows is always a list of sequences here, so a single string
            # yields shape (1, n) and empty input no longer needs indexing.
            return {
                "input_ids": torch.tensor(rows, dtype=torch.long),
                "attention_mask": torch.tensor(masks, dtype=torch.long),
            }

        # Return dictionary for compatibility
        if is_batch:
            return {"input_ids": rows, "attention_mask": masks}
        return {"input_ids": rows[0], "attention_mask": masks[0]}

    def encode(self, text, add_special_tokens=True, *args, **kwargs):
        """Return position-based ids for a string, or a list of them for a batch.

        ``add_special_tokens`` is accepted for API compatibility; no special
        tokens are ever added.
        """
        if isinstance(text, str):
            return self._placeholder_ids(text)
        return [self._placeholder_ids(item) for item in text]

    def decode(self, token_ids, skip_special_tokens=True, *args, **kwargs):
        """Render ids as space-separated ``token<id>`` placeholders.

        ``skip_special_tokens`` is accepted for API compatibility; every id
        is rendered.
        """
        # Unwrap tensor/NumPy scalars so ids render as "token5" rather than
        # "tokentensor(5)".
        plain = (tid.item() if hasattr(tid, "item") else tid for tid in token_ids)
        return " ".join(f"token{tid}" for tid in plain)

    def batch_decode(self, sequences, skip_special_tokens=True, *args, **kwargs):
        """Decode each sequence in ``sequences`` with :meth:`decode`."""
        return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences]
217
+
218
+
219
def get_tokenizer(model_name="gpt2", use_fast=True):
    """Return the registry's tokenizer if one is registered, else a new wrapper.

    The optional ``service_registry`` module is consulted first. When it
    cannot be imported, or holds no tokenizer, a fresh ``TokenizerWrapper``
    is built from ``model_name`` and ``use_fast``.
    """
    registered = None
    hit = False
    try:
        from service_registry import registry, TOKENIZER
        if registry.has(TOKENIZER):
            logger.info("Retrieved tokenizer from registry")
            registered = registry.get(TOKENIZER)
            hit = True
    except ImportError:
        # Registry is optional; fall through and build a tokenizer instead.
        hit = False

    if hit:
        return registered

    # Nothing registered: create a new tokenizer.
    return TokenizerWrapper(model_name=model_name, use_fast=use_fast)