""" Tokenizer wrapper to handle different tokenizer types with a consistent interface """ import os import sys import logging from typing import Dict, List, Optional, Union, Any import torch logger = logging.getLogger(__name__) # Set memory optimization flag os.environ["LOW_MEMORY_MODE"] = "1" class TokenizerWrapper: """A wrapper for tokenizers with common functionality for GPT-2 and BERT models""" def __init__(self, model_name="gpt2", use_fast=True, *args, **kwargs): self.model_name = model_name self.use_fast = use_fast self.tokenizer = None self._initialize_tokenizer() # Special token defaults self.eos_token = "" # Fixed: This was the unterminated string self.pad_token = "[PAD]" self.unk_token = "[UNK]" self.mask_token = "[MASK]" self.bos_token = "" # Ensure pad_token is always set (critical for GPT-2) self._ensure_pad_token() logger.info(f"Initialized TokenizerWrapper with {model_name}") def _initialize_tokenizer(self): """Initialize the actual tokenizer with graceful fallbacks""" try: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( self.model_name, use_fast=self.use_fast ) logger.info(f"Successfully loaded {self.model_name} tokenizer") except Exception as e: logger.warning(f"Error loading {self.model_name} tokenizer: {e}") try: # Fallback to GPT-2 from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained("gpt2") logger.info("Loaded fallback GPT-2 tokenizer") except Exception as e2: logger.error(f"Failed to load fallback tokenizer: {e2}") # Create minimal placeholder self.tokenizer = MinimalTokenizer() logger.warning("Using minimal placeholder tokenizer") def _ensure_pad_token(self): """Ensure the pad_token is set (especially important for GPT-2)""" if not self.tokenizer: return if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None: # GPT-2 doesn't have a pad_token by default, use eos_token instead if hasattr(self.tokenizer, 'eos_token') and self.tokenizer.eos_token: self.tokenizer.pad_token = self.tokenizer.eos_token self.pad_token = self.tokenizer.pad_token logger.info(f"Set pad_token to eos_token: {self.pad_token}") else: # Last resort self.tokenizer.pad_token = "[PAD]" self.pad_token = "[PAD]" logger.info("Set default pad_token: [PAD]") @property def vocab_size(self) -> int: """Get the vocabulary size of the tokenizer""" if hasattr(self.tokenizer, 'vocab_size'): return self.tokenizer.vocab_size elif hasattr(self.tokenizer, 'get_vocab'): return len(self.tokenizer.get_vocab()) return 50257 # Default GPT-2 vocab size @property def pad_token_id(self) -> int: """Get pad token ID with fallback""" if hasattr(self.tokenizer, 'pad_token_id') and self.tokenizer.pad_token_id is not None: return self.tokenizer.pad_token_id elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: return self.tokenizer.eos_token_id return 0 # Last resort fallback @property def eos_token_id(self) -> int: """Get EOS token ID with fallback""" if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: return self.tokenizer.eos_token_id return 50256 # Default for GPT-2 def __call__(self, text, *args, **kwargs): """Delegate to the actual tokenizer""" if self.tokenizer is None: logger.error("Tokenizer not initialized") # Create minimal output compatible with model expectations if isinstance(text, str): # Single string input dummy_ids = torch.ones((1, 10), dtype=torch.long) return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)} # Batch input 
            batch_size = len(text) if isinstance(text, list) else 1
            dummy_ids = torch.ones((batch_size, 10), dtype=torch.long)
            return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}

        return self.tokenizer(text, *args, **kwargs)

    def encode(self, text, *args, **kwargs):
        """Encode text to token IDs"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            if isinstance(text, str):
                return [1] * 10  # Return minimal dummy encoding
            return [[1] * 10 for _ in text]  # Batch of dummy encodings
        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids, *args, **kwargs):
        """Decode token IDs to text"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return "Error: Tokenizer not initialized"
        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_decode(self, sequences, *args, **kwargs):
        """Decode multiple sequences"""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return ["Error: Tokenizer not initialized"] * len(sequences)
        return self.tokenizer.batch_decode(sequences, *args, **kwargs)

    def __getattr__(self, name):
        """Delegate to the underlying tokenizer for missing attributes"""
        # Read via __dict__ to avoid infinite recursion if __getattr__ fires
        # before self.tokenizer has been assigned in __init__
        tokenizer = self.__dict__.get("tokenizer")
        if tokenizer is not None and hasattr(tokenizer, name):
            return getattr(tokenizer, name)
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")


class MinimalTokenizer:
    """Minimal tokenizer implementation for fallback"""

    def __init__(self):
        self.pad_token = "[PAD]"
        self.pad_token_id = 0
        self.eos_token = "<|endoftext|>"
        self.eos_token_id = 1
        self.bos_token = "<|endoftext|>"
        self.bos_token_id = 2
        self.unk_token = "[UNK]"
        self.unk_token_id = 3
        self.vocab_size = 50257  # Standard GPT-2 vocab size
        logger.warning(
            "Using minimal placeholder tokenizer with no actual encoding/decoding capability"
        )

    def __call__(self, text, return_tensors=None, padding=False, truncation=False,
                 max_length=None, *args, **kwargs):
        """Minimal tokenize implementation"""
        # Simple whitespace "tokenizer"; IDs are positional placeholders,
        # not real vocabulary indices
        if isinstance(text, str):
            # Handle a single string
            tokens = text.split()[:max_length] if max_length else text.split()
            input_ids = [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]
            if padding and max_length:
                pad_length = max(0, max_length - len(input_ids))
                input_ids = input_ids + [self.pad_token_id] * pad_length
        else:
            # Handle a list of strings
            results = []
            max_len = 0
            for t in text:
                tokens = t.split()[:max_length] if max_length else t.split()
                ids = [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]
                results.append(ids)
                max_len = max(max_len, len(ids))
            # Pad to the longest sequence if requested
            if padding:
                results = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in results]
            input_ids = results

        # Convert to tensors if requested (batches must be equal length,
        # i.e. padded, for the tensor conversion to succeed)
        if return_tensors == "pt":
            if isinstance(input_ids[0], list):
                input_ids = torch.tensor(input_ids, dtype=torch.long)
            else:
                input_ids = torch.tensor([input_ids], dtype=torch.long)
            attention_mask = torch.ones_like(input_ids)
            return {"input_ids": input_ids, "attention_mask": attention_mask}

        # Return a plain dictionary for compatibility, with a per-sequence
        # mask for batch input
        if input_ids and isinstance(input_ids[0], list):
            attention_mask = [[1] * len(ids) for ids in input_ids]
        else:
            attention_mask = [1] * len(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def encode(self, text, add_special_tokens=True, *args, **kwargs):
        """Minimal encode implementation"""
        if isinstance(text, str):
            tokens = text.split()
            return [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]
        return [[i % (self.vocab_size - 4) + 4 for i in range(len(t.split()))] for t in text]
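
    # Note: decoding below is intentionally lossy; the positional placeholder
    # IDs produced above cannot be mapped back to the original words.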
implementation""" return " ".join(["token" + str(i) for i in token_ids]) def batch_decode(self, sequences, skip_special_tokens=True, *args, **kwargs): """Minimal batch decode implementation""" return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences] def get_tokenizer(model_name="gpt2", use_fast=True): """Create a tokenizer for the specified model""" # First check registry try: from service_registry import registry, TOKENIZER if registry.has(TOKENIZER): logger.info("Retrieved tokenizer from registry") return registry.get(TOKENIZER) except ImportError: pass # Create a new tokenizer return TokenizerWrapper(model_name=model_name, use_fast=use_fast)