# Provenance (copied from the Hugging Face file page): uploaded by WildnerveAI
# ("Upload 2 files"), commit 53d82e6.
"""
Tokenizer wrapper to handle different tokenizer types with a consistent interface
"""
import os
import sys
import logging
from typing import Dict, List, Optional, Union, Any
import torch
logger = logging.getLogger(__name__)

# Opt the whole process into low-memory behaviour as soon as this module is
# imported. Any value already present in the environment is overwritten.
# NOTE(review): the flag is presumably read by other modules; nothing in this
# file reads it.
_LOW_MEMORY_ENV_VAR = "LOW_MEMORY_MODE"
os.environ[_LOW_MEMORY_ENV_VAR] = "1"
class TokenizerWrapper:
    """Uniform front-end over GPT-2/BERT-style tokenizers.

    Tries to load a Hugging Face ``AutoTokenizer`` for ``model_name``. If that
    fails, it falls back to the ``"gpt2"`` tokenizer and, as a last resort, to
    ``MinimalTokenizer``. Attributes the wrapper does not define itself are
    looked up on the wrapped tokenizer (see ``__getattr__``).
    """

    # Used for any special token the wrapped tokenizer does not define.
    _SPECIAL_TOKEN_DEFAULTS = {
        "eos_token": "</s>",
        "pad_token": "[PAD]",
        "unk_token": "[UNK]",
        "mask_token": "[MASK]",
        "bos_token": "<s>",
    }

    def __init__(self, model_name="gpt2", use_fast=True, *args, **kwargs):
        """Load the tokenizer and resolve its special tokens.

        Args:
            model_name: Model id or path passed to ``AutoTokenizer.from_pretrained``.
            use_fast: Forwarded to ``AutoTokenizer.from_pretrained``.
            *args, **kwargs: Accepted for call-site compatibility; ignored.
        """
        self.model_name = model_name
        self.use_fast = use_fast
        self.tokenizer = None
        self._initialize_tokenizer()
        # Mirror the wrapped tokenizer's special tokens. Otherwise e.g.
        # ``self.eos_token`` ("</s>") and ``self.eos_token_id`` (GPT-2's 50256)
        # would describe different tokens.
        self._sync_special_tokens()
        # Ensure pad_token is always set (critical for GPT-2)
        self._ensure_pad_token()
        logger.info(f"Initialized TokenizerWrapper with {model_name}")

    def _initialize_tokenizer(self):
        """Set ``self.tokenizer``: requested model, then GPT-2, then MinimalTokenizer."""
        try:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=self.use_fast,
            )
            logger.info(f"Successfully loaded {self.model_name} tokenizer")
        except Exception as e:
            # Presumably broad on purpose: a missing package, no network, or an
            # unknown model id should all degrade to a fallback, not crash.
            logger.warning(f"Error loading {self.model_name} tokenizer: {e}")
            try:
                from transformers import AutoTokenizer

                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                logger.info("Loaded fallback GPT-2 tokenizer")
            except Exception as e2:
                logger.error(f"Failed to load fallback tokenizer: {e2}")
                self.tokenizer = MinimalTokenizer()
                logger.warning("Using minimal placeholder tokenizer")

    def _lookup_special_token(self, name):
        """Return the wrapped tokenizer's special token *name*, or None if unset.

        Hugging Face tokenizers expose ``special_tokens_map`` (only the tokens
        that are actually set). Reading it avoids the "not set yet" error logs
        that the individual properties emit. Other tokenizers are queried by
        plain attribute lookup.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            return None
        token_map = getattr(tokenizer, "special_tokens_map", None)
        if isinstance(token_map, dict):
            value = token_map.get(name)
        else:
            value = getattr(tokenizer, name, None)
        return value if isinstance(value, str) and value else None

    def _sync_special_tokens(self):
        """Set eos/pad/unk/mask/bos attributes from the wrapped tokenizer or defaults."""
        for name, default in self._SPECIAL_TOKEN_DEFAULTS.items():
            setattr(self, name, self._lookup_special_token(name) or default)

    def _ensure_pad_token(self):
        """Give the wrapped tokenizer a pad token if it lacks one (GPT-2 ships without).

        Reuses the EOS token when available, otherwise assigns the literal
        ``"[PAD]"``. ``self.pad_token`` is updated to match.
        """
        if self.tokenizer is None:
            return
        if getattr(self.tokenizer, "pad_token", None) is not None:
            return
        eos_token = getattr(self.tokenizer, "eos_token", None)
        if eos_token:
            # GPT-2 doesn't have a pad_token by default, use eos_token instead
            self.tokenizer.pad_token = eos_token
            self.pad_token = self.tokenizer.pad_token
            logger.info(f"Set pad_token to eos_token: {self.pad_token}")
        else:
            # Last resort.
            # NOTE(review): on a Hugging Face tokenizer whose vocab lacks
            # "[PAD]" this likely maps pad_token_id to the unk id — confirm.
            self.tokenizer.pad_token = "[PAD]"
            self.pad_token = "[PAD]"
            logger.info("Set default pad_token: [PAD]")

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the wrapped tokenizer, defaulting to GPT-2's 50257."""
        if hasattr(self.tokenizer, "vocab_size"):
            return self.tokenizer.vocab_size
        elif hasattr(self.tokenizer, "get_vocab"):
            return len(self.tokenizer.get_vocab())
        return 50257  # Default GPT-2 vocab size

    @property
    def pad_token_id(self) -> int:
        """Pad token id; falls back to the EOS id, then to 0."""
        if hasattr(self.tokenizer, "pad_token_id") and self.tokenizer.pad_token_id is not None:
            return self.tokenizer.pad_token_id
        elif hasattr(self.tokenizer, "eos_token_id") and self.tokenizer.eos_token_id is not None:
            return self.tokenizer.eos_token_id
        return 0  # Last resort fallback

    @property
    def eos_token_id(self) -> int:
        """EOS token id; falls back to 50256 (GPT-2's end-of-text id)."""
        if hasattr(self.tokenizer, "eos_token_id") and self.tokenizer.eos_token_id is not None:
            return self.tokenizer.eos_token_id
        return 50256  # Default for GPT-2

    def __call__(self, text, *args, **kwargs):
        """Tokenize *text* with the wrapped tokenizer.

        If no tokenizer is set, returns dummy ``(batch, 10)`` tensors of ones
        for ``input_ids`` and ``attention_mask`` so callers keep running.
        """
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            # Create minimal output compatible with model expectations
            if isinstance(text, str):
                dummy_ids = torch.ones((1, 10), dtype=torch.long)
                return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}
            batch_size = len(text) if isinstance(text, (list, tuple)) else 1
            dummy_ids = torch.ones((batch_size, 10), dtype=torch.long)
            return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}
        return self.tokenizer(text, *args, **kwargs)

    def encode(self, text, *args, **kwargs):
        """Encode text (or a list of texts) to token ids; dummy ids if uninitialized."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            if isinstance(text, str):
                return [1] * 10  # Return minimal dummy encoding
            return [[1] * 10 for _ in text]  # Batch of dummy encodings
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids, *args, **kwargs):
        """Decode token ids to text; returns an error string if uninitialized."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return "Error: Tokenizer not initialized"
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_decode(self, sequences, *args, **kwargs):
        """Decode several id sequences; one error string per sequence if uninitialized."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return ["Error: Tokenizer not initialized"] * len(sequences)
        return self.tokenizer.batch_decode(sequences, *args, **kwargs)

    def __getattr__(self, name):
        """Forward unknown (non-dunder) attributes to the wrapped tokenizer.

        Raises:
            AttributeError: if the name is a dunder, no tokenizer is set, or the
                tokenizer lacks the attribute.
        """
        # Never forward protocol hooks (__getstate__, __deepcopy__, ...). Otherwise
        # copy/pickle would act on the inner tokenizer instead of this wrapper.
        if not (name.startswith("__") and name.endswith("__")):
            # Read through __dict__: on a half-built instance (copy/unpickle,
            # __new__ without __init__) ``self.tokenizer`` would re-enter this
            # method and recurse forever.
            tokenizer = self.__dict__.get("tokenizer")
            if tokenizer is not None:
                try:
                    return getattr(tokenizer, name)
                except AttributeError:
                    pass
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class MinimalTokenizer:
    """Word-splitting placeholder tokenizer used when no real tokenizer loads.

    There is no vocabulary. The i-th whitespace-separated word of an input gets
    id ``i % (vocab_size - 4) + 4``, so ids depend only on word position and
    ``decode`` can only produce ``"token<id>"`` placeholders. Ids 0-3 are the
    pad/eos/bos/unk special tokens.
    """

    def __init__(self):
        """Set the special-token strings/ids and a GPT-2-sized vocab_size."""
        self.pad_token = "[PAD]"
        self.pad_token_id = 0
        self.eos_token = "</s>"
        self.eos_token_id = 1
        self.bos_token = "<s>"
        self.bos_token_id = 2
        self.unk_token = "[UNK]"
        self.unk_token_id = 3
        self.vocab_size = 50257  # Standard GPT-2 vocab size
        logger.warning("Using minimal placeholder tokenizer with no actual encoding/decoding capability")

    def _positional_ids(self, text, max_length=None):
        """Map each word of *text* to its positional id, keeping at most *max_length* words."""
        words = text.split()
        if max_length:
            words = words[:max_length]
        return [i % (self.vocab_size - 4) + 4 for i in range(len(words))]

    def __call__(self, text, return_tensors=None, padding=False, truncation=False, max_length=None, *args, **kwargs):
        """Tokenize a string or a list of strings.

        Args:
            text: A single string, or an iterable of strings for a batch.
            return_tensors: ``"pt"`` returns 2-D ``torch.long`` tensors; any
                other value returns Python lists.
            padding: Single string: pad to *max_length* when both are set.
                Batch: pad every row to the longest row, or to *max_length* when
                ``padding == "max_length"``.
            truncation: Accepted for API compatibility; ignored (a truthy
                *max_length* always truncates).
            max_length: Maximum number of words kept per input.

        Returns:
            ``{"input_ids": ..., "attention_mask": ...}``. The mask is 1 for real
            tokens and 0 for padding, and has the same shape as ``input_ids``.
        """
        is_single = isinstance(text, str)
        if is_single:
            input_ids = self._positional_ids(text, max_length)
            attention_mask = [1] * len(input_ids)
            if padding and max_length:
                pad_length = max(0, max_length - len(input_ids))
                input_ids = input_ids + [self.pad_token_id] * pad_length
                attention_mask = attention_mask + [0] * pad_length
        else:
            rows = [self._positional_ids(t, max_length) for t in text]
            if padding:
                if padding == "max_length" and max_length:
                    target = max_length
                else:
                    target = max((len(row) for row in rows), default=0)
                input_ids = [row + [self.pad_token_id] * (target - len(row)) for row in rows]
                attention_mask = [[1] * len(row) + [0] * (target - len(row)) for row in rows]
            else:
                input_ids = rows
                attention_mask = [[1] * len(row) for row in rows]

        if return_tensors == "pt":
            if is_single:
                # Add the batch dimension so a single input becomes (1, seq_len).
                input_ids, attention_mask = [input_ids], [attention_mask]
            if not input_ids:
                # torch.tensor([]) would be 1-D; keep an empty batch 2-D.
                empty = torch.zeros((0, 0), dtype=torch.long)
                return {"input_ids": empty, "attention_mask": empty.clone()}
            # Ragged, unpadded batches make torch.tensor raise ValueError.
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            }

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def encode(self, text, add_special_tokens=True, *args, **kwargs):
        """Return positional ids for a string, or a list of id lists for a batch.

        ``add_special_tokens`` is accepted for compatibility and ignored.
        """
        if isinstance(text, str):
            return self._positional_ids(text)
        return [self._positional_ids(t) for t in text]

    def decode(self, token_ids, skip_special_tokens=True, *args, **kwargs):
        """Render ids as space-separated ``"token<id>"`` placeholders.

        Accepts a list of ids, a single int, or anything with ``tolist()`` (e.g.
        a torch tensor). When *skip_special_tokens* is true, ids 0-3
        (pad/eos/bos/unk) are dropped.
        """
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            special_ids = (self.pad_token_id, self.eos_token_id, self.bos_token_id, self.unk_token_id)
            token_ids = [i for i in token_ids if i not in special_ids]
        return " ".join("token" + str(i) for i in token_ids)

    def batch_decode(self, sequences, skip_special_tokens=True, *args, **kwargs):
        """Decode each sequence in *sequences* (a list of id lists or a 2-D tensor)."""
        return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences]
# Sentinel meaning "the registry did not supply a tokenizer" (None is a legal
# value for the registry to return).
_NO_REGISTERED_TOKENIZER = object()


def get_tokenizer(model_name="gpt2", use_fast=True):
    """Return the shared tokenizer if one is registered, else build a new wrapper.

    If the optional ``service_registry`` module can be imported and already holds
    a ``TOKENIZER`` entry, that object is returned unchanged. In that case
    *model_name* and *use_fast* are not consulted. Otherwise a fresh
    ``TokenizerWrapper(model_name=..., use_fast=...)`` is created; it is not
    added to the registry.
    """
    registered = _NO_REGISTERED_TOKENIZER
    try:
        from service_registry import registry, TOKENIZER

        if registry.has(TOKENIZER):
            logger.info("Retrieved tokenizer from registry")
            registered = registry.get(TOKENIZER)
    except ImportError:
        # The registry is optional; fall through to a private instance.
        pass

    if registered is not _NO_REGISTERED_TOKENIZER:
        return registered
    return TokenizerWrapper(model_name=model_name, use_fast=use_fast)