# NOTE: extraction residue removed here — the original paste carried a
# code-hosting blame-view header (file size, commit hashes, line-number
# gutter) that was never part of this Python module.
"""
Tokenizer wrapper to handle different tokenizer types with a consistent interface
"""
import os
import sys
import logging
from typing import Dict, List, Optional, Union, Any
import torch
logger = logging.getLogger(__name__)
# Set memory optimization flag
os.environ["LOW_MEMORY_MODE"] = "1"
class TokenizerWrapper:
    """A wrapper for tokenizers with common functionality for GPT-2 and BERT models.

    Wraps a Hugging Face ``AutoTokenizer`` when one can be loaded, falling back
    first to the plain GPT-2 tokenizer and finally to the in-file
    ``MinimalTokenizer`` placeholder.  When no tokenizer is available at all,
    ``__call__``/``encode``/``decode`` return harmless dummy outputs instead of
    raising, so downstream code keeps working in a degraded mode.
    """

    def __init__(self, model_name="gpt2", use_fast=True, *args, **kwargs):
        """Load a tokenizer for *model_name* and normalize its special tokens.

        Args:
            model_name: Hugging Face model id to load a tokenizer for.
            use_fast: Prefer the fast (Rust-backed) tokenizer variant.
            *args, **kwargs: Accepted for call-site compatibility; unused.
        """
        self.model_name = model_name
        self.use_fast = use_fast
        self.tokenizer = None
        self._initialize_tokenizer()
        # Special token defaults stored on the wrapper itself; the underlying
        # tokenizer keeps its own native special tokens.
        self.eos_token = "</s>"
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.mask_token = "[MASK]"
        self.bos_token = "<s>"
        # Ensure pad_token is always set (critical for GPT-2, which ships
        # without one).
        self._ensure_pad_token()
        logger.info(f"Initialized TokenizerWrapper with {model_name}")

    def _initialize_tokenizer(self):
        """Initialize the actual tokenizer with graceful fallbacks.

        Order: requested model -> plain GPT-2 -> MinimalTokenizer placeholder.
        Never raises; failures are logged and the next fallback is tried.
        """
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=self.use_fast
            )
            logger.info(f"Successfully loaded {self.model_name} tokenizer")
        except Exception as e:
            logger.warning(f"Error loading {self.model_name} tokenizer: {e}")
            try:
                # Fallback to GPT-2.  The import is repeated on purpose: the
                # outer failure may have been the transformers import itself.
                from transformers import AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                logger.info("Loaded fallback GPT-2 tokenizer")
            except Exception as e2:
                logger.error(f"Failed to load fallback tokenizer: {e2}")
                # Create minimal placeholder
                self.tokenizer = MinimalTokenizer()
                logger.warning("Using minimal placeholder tokenizer")

    def _ensure_pad_token(self):
        """Ensure the pad_token is set (especially important for GPT-2)."""
        if not self.tokenizer:
            return
        if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
            # GPT-2 doesn't have a pad_token by default, use eos_token instead
            if hasattr(self.tokenizer, 'eos_token') and self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.pad_token = self.tokenizer.pad_token
                logger.info(f"Set pad_token to eos_token: {self.pad_token}")
            else:
                # Last resort
                self.tokenizer.pad_token = "[PAD]"
                self.pad_token = "[PAD]"
                logger.info("Set default pad_token: [PAD]")

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the underlying tokenizer (GPT-2 size if unknown)."""
        if hasattr(self.tokenizer, 'vocab_size'):
            return self.tokenizer.vocab_size
        elif hasattr(self.tokenizer, 'get_vocab'):
            return len(self.tokenizer.get_vocab())
        return 50257  # Default GPT-2 vocab size

    @property
    def pad_token_id(self) -> int:
        """Pad token id, falling back to the eos id, then 0."""
        if hasattr(self.tokenizer, 'pad_token_id') and self.tokenizer.pad_token_id is not None:
            return self.tokenizer.pad_token_id
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            return self.tokenizer.eos_token_id
        return 0  # Last resort fallback

    @property
    def eos_token_id(self) -> int:
        """EOS token id, falling back to the GPT-2 default (50256)."""
        if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            return self.tokenizer.eos_token_id
        return 50256  # Default for GPT-2

    def __call__(self, text, *args, **kwargs):
        """Tokenize *text* by delegating to the underlying tokenizer.

        If no tokenizer is initialized, returns dummy all-ones tensors shaped
        (batch, 10) so model code downstream does not crash.
        """
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            # Create minimal output compatible with model expectations
            if isinstance(text, str):
                # Single string input
                dummy_ids = torch.ones((1, 10), dtype=torch.long)
                return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}
            # Batch input
            batch_size = len(text) if isinstance(text, list) else 1
            dummy_ids = torch.ones((batch_size, 10), dtype=torch.long)
            return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}
        return self.tokenizer(text, *args, **kwargs)

    def encode(self, text, *args, **kwargs):
        """Encode text to token IDs (dummy ids if no tokenizer is available)."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            if isinstance(text, str):
                return [1] * 10  # Return minimal dummy encoding
            return [[1] * 10 for _ in text]  # Batch of dummy encodings
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids, *args, **kwargs):
        """Decode token IDs to text (error string if no tokenizer is available)."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return "Error: Tokenizer not initialized"
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_decode(self, sequences, *args, **kwargs):
        """Decode multiple sequences (error strings if no tokenizer is available)."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return ["Error: Tokenizer not initialized"] * len(sequences)
        return self.tokenizer.batch_decode(sequences, *args, **kwargs)

    def __getattr__(self, name):
        """Delegate unknown attributes to the underlying tokenizer.

        Bug fix: the original read ``self.tokenizer`` here, which re-enters
        ``__getattr__`` and recurses infinitely (RecursionError instead of
        AttributeError) whenever ``tokenizer`` is not yet in ``__dict__`` —
        e.g. during unpickling or ``copy`` before ``__init__`` has run.
        Reading ``__dict__`` directly avoids the re-entry.
        """
        tokenizer = self.__dict__.get("tokenizer")
        if tokenizer is not None and hasattr(tokenizer, name):
            return getattr(tokenizer, name)
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class MinimalTokenizer:
    """Minimal tokenizer implementation for fallback.

    Produces position-based dummy token ids (it cannot encode or decode a real
    vocabulary); used only when no real tokenizer could be loaded.
    """

    def __init__(self):
        self.pad_token = "[PAD]"
        self.pad_token_id = 0
        self.eos_token = "</s>"
        self.eos_token_id = 1
        self.bos_token = "<s>"
        self.bos_token_id = 2
        self.unk_token = "[UNK]"
        self.unk_token_id = 3
        self.vocab_size = 50257  # Standard GPT-2 vocab size
        logger.warning("Using minimal placeholder tokenizer with no actual encoding/decoding capability")

    def _dummy_ids(self, text, max_length=None):
        """Return position-based dummy ids for whitespace-split tokens.

        Ids 0-3 are reserved for the special tokens above, so real positions
        are offset by 4.  The id depends on the position, not the word.
        """
        tokens = text.split()[:max_length] if max_length else text.split()
        return [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]

    def __call__(self, text, return_tensors=None, padding=False, truncation=False, max_length=None, *args, **kwargs):
        """Minimal tokenize implementation.

        Bug fixes vs. the original:
        - padded positions now get attention_mask 0 (standard convention)
          instead of 1, in both the list and tensor return paths;
        - for batch input without tensors, attention_mask is one mask per
          sequence instead of a flat list of length batch_size.
        """
        if isinstance(text, str):
            # Single string
            ids = self._dummy_ids(text, max_length)
            mask = [1] * len(ids)
            if padding and max_length:
                pad_length = max(0, max_length - len(ids))
                ids = ids + [self.pad_token_id] * pad_length
                mask = mask + [0] * pad_length
            input_ids, attention_mask = ids, mask
        else:
            # List of strings
            seqs = [self._dummy_ids(t, max_length) for t in text]
            masks = [[1] * len(ids) for ids in seqs]
            if padding:
                max_len = max((len(ids) for ids in seqs), default=0)
                seqs = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in seqs]
                masks = [m + [0] * (max_len - len(m)) for m in masks]
            input_ids, attention_mask = seqs, masks
        # Convert to tensors if requested
        if return_tensors == "pt":
            import torch
            # NOTE(review): a ragged batch (padding=False, unequal lengths)
            # raises inside torch.tensor, matching the original behavior.
            if input_ids and isinstance(input_ids[0], list):
                return {"input_ids": torch.tensor(input_ids, dtype=torch.long),
                        "attention_mask": torch.tensor(attention_mask, dtype=torch.long)}
            return {"input_ids": torch.tensor([input_ids], dtype=torch.long),
                    "attention_mask": torch.tensor([attention_mask], dtype=torch.long)}
        # Return dictionary for compatibility
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def encode(self, text, add_special_tokens=True, *args, **kwargs):
        """Minimal encode implementation (position-based dummy ids)."""
        if isinstance(text, str):
            return self._dummy_ids(text)
        return [self._dummy_ids(t) for t in text]

    def decode(self, token_ids, skip_special_tokens=True, *args, **kwargs):
        """Minimal decode implementation: maps each id to 'token<id>'."""
        return " ".join("token" + str(i) for i in token_ids)

    def batch_decode(self, sequences, skip_special_tokens=True, *args, **kwargs):
        """Minimal batch decode implementation."""
        return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences]
def get_tokenizer(model_name="gpt2", use_fast=True):
    """Return a tokenizer for *model_name*.

    Prefers a shared instance from the optional service registry; when the
    registry is unavailable or empty, constructs a fresh TokenizerWrapper.
    """
    # First check registry (optional dependency — absence is not an error).
    try:
        from service_registry import registry, TOKENIZER
        registered = registry.has(TOKENIZER)
    except ImportError:
        registered = False
    if registered:
        logger.info("Retrieved tokenizer from registry")
        return registry.get(TOKENIZER)
    # Create a new tokenizer
    return TokenizerWrapper(model_name=model_name, use_fast=use_fast)