"""
Tokenizer wrapper to handle different tokenizer types with a consistent interface.
"""

import logging
import os

import torch

logger = logging.getLogger(__name__)

# Hint for downstream components that check this flag; set at import time.
os.environ["LOW_MEMORY_MODE"] = "1"

class TokenizerWrapper:
    """A wrapper for tokenizers with common functionality for GPT-2 and BERT models."""

    def __init__(self, model_name="gpt2", use_fast=True, *args, **kwargs):
        self.model_name = model_name
        self.use_fast = use_fast
        self.tokenizer = None
        self._initialize_tokenizer()

        # Fallback special-token defaults, overridden by the loaded tokenizer's
        # own values where available (GPT-2, for instance, uses <|endoftext|>).
        self.eos_token = "</s>"
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.mask_token = "[MASK]"
        self.bos_token = "<s>"
        for name in ("eos_token", "pad_token", "unk_token", "mask_token", "bos_token"):
            value = getattr(self.tokenizer, name, None)
            if value is not None:
                setattr(self, name, value)

        self._ensure_pad_token()

        logger.info(f"Initialized TokenizerWrapper with {model_name}")

    def _initialize_tokenizer(self):
        """Initialize the actual tokenizer with graceful fallbacks."""
        try:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=self.use_fast,
            )
            logger.info(f"Successfully loaded {self.model_name} tokenizer")
        except Exception as e:
            logger.warning(f"Error loading {self.model_name} tokenizer: {e}")
            try:
                # First fallback: the stock GPT-2 tokenizer.
                from transformers import AutoTokenizer

                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
                logger.info("Loaded fallback GPT-2 tokenizer")
            except Exception as e2:
                logger.error(f"Failed to load fallback tokenizer: {e2}")
                # Last resort: a placeholder that satisfies the interface but
                # cannot really encode or decode.
                self.tokenizer = MinimalTokenizer()
                logger.warning("Using minimal placeholder tokenizer")

    def _ensure_pad_token(self):
        """Ensure the pad_token is set (especially important for GPT-2)."""
        if not self.tokenizer:
            return

        if getattr(self.tokenizer, "pad_token", None) is None:
            if getattr(self.tokenizer, "eos_token", None):
                # Reusing EOS as the pad token is the usual convention for GPT-2.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.pad_token = self.tokenizer.pad_token
                logger.info(f"Set pad_token to eos_token: {self.pad_token}")
            else:
                self.tokenizer.pad_token = "[PAD]"
                self.pad_token = "[PAD]"
                logger.info("Set default pad_token: [PAD]")

    @property
    def vocab_size(self) -> int:
        """Get the vocabulary size of the tokenizer."""
        if hasattr(self.tokenizer, "vocab_size"):
            return self.tokenizer.vocab_size
        elif hasattr(self.tokenizer, "get_vocab"):
            return len(self.tokenizer.get_vocab())
        # GPT-2 vocabulary size as a last-resort default.
        return 50257

    @property
    def pad_token_id(self) -> int:
        """Get pad token ID with fallback."""
        if getattr(self.tokenizer, "pad_token_id", None) is not None:
            return self.tokenizer.pad_token_id
        elif getattr(self.tokenizer, "eos_token_id", None) is not None:
            return self.tokenizer.eos_token_id
        return 0

    @property
    def eos_token_id(self) -> int:
        """Get EOS token ID with fallback."""
        if getattr(self.tokenizer, "eos_token_id", None) is not None:
            return self.tokenizer.eos_token_id
        # GPT-2's <|endoftext|> ID as a last-resort default.
        return 50256

    def __call__(self, text, *args, **kwargs):
        """Delegate to the actual tokenizer."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            # Return dummy output of the right batch shape so callers can proceed.
            batch_size = len(text) if isinstance(text, (list, tuple)) else 1
            dummy_ids = torch.ones((batch_size, 10), dtype=torch.long)
            return {"input_ids": dummy_ids, "attention_mask": torch.ones_like(dummy_ids)}

        return self.tokenizer(text, *args, **kwargs)

    def encode(self, text, *args, **kwargs):
        """Encode text to token IDs."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            if isinstance(text, str):
                return [1] * 10
            return [[1] * 10 for _ in text]

        return self.tokenizer.encode(text, *args, **kwargs)

    def decode(self, token_ids, *args, **kwargs):
        """Decode token IDs to text."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return "Error: Tokenizer not initialized"

        return self.tokenizer.decode(token_ids, *args, **kwargs)

    def batch_decode(self, sequences, *args, **kwargs):
        """Decode multiple sequences."""
        if self.tokenizer is None:
            logger.error("Tokenizer not initialized")
            return ["Error: Tokenizer not initialized"] * len(sequences)

        return self.tokenizer.batch_decode(sequences, *args, **kwargs)

    def __getattr__(self, name):
        """Delegate to the underlying tokenizer for missing attributes."""
        # Look up via __dict__ to avoid infinite recursion when 'tokenizer'
        # itself is not yet set (e.g. during unpickling, before __init__ runs).
        tokenizer = self.__dict__.get("tokenizer")
        if tokenizer is not None and hasattr(tokenizer, name):
            return getattr(tokenizer, name)
        raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
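

# Hedged usage sketch, assuming `transformers` is installed and the "gpt2"
# vocabulary is available locally or downloadable:
#
#     wrapper = TokenizerWrapper(model_name="gpt2")
#     batch = wrapper(["hello world", "hi"], return_tensors="pt", padding=True)
#     texts = wrapper.batch_decode(batch["input_ids"], skip_special_tokens=True)
#
# Padding works even for GPT-2 because _ensure_pad_token reuses EOS as pad.
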
class MinimalTokenizer:
    """Minimal tokenizer implementation for fallback."""

    def __init__(self):
        self.pad_token = "[PAD]"
        self.pad_token_id = 0
        self.eos_token = "</s>"
        self.eos_token_id = 1
        self.bos_token = "<s>"
        self.bos_token_id = 2
        self.unk_token = "[UNK]"
        self.unk_token_id = 3
        self.vocab_size = 50257
        logger.warning("Using minimal placeholder tokenizer with no actual encoding/decoding capability")

    def __call__(self, text, return_tensors=None, padding=False, truncation=False, max_length=None, *args, **kwargs):
        """Minimal tokenize implementation."""

        def _dummy_ids(t):
            # Positional dummy IDs, offset past the 4 reserved special tokens.
            tokens = t.split()[:max_length] if max_length else t.split()
            return [i % (self.vocab_size - 4) + 4 for i in range(len(tokens))]

        if isinstance(text, str):
            input_ids = _dummy_ids(text)
            attention_mask = [1] * len(input_ids)
            if padding and max_length:
                pad_length = max(0, max_length - len(input_ids))
                input_ids = input_ids + [self.pad_token_id] * pad_length
                attention_mask = attention_mask + [0] * pad_length
        else:
            results = [_dummy_ids(t) for t in text]
            max_len = max((len(ids) for ids in results), default=0)
            attention_mask = [[1] * len(ids) for ids in results]
            if padding:
                # Pad to the longest sequence in the batch and mask out the pads.
                attention_mask = [m + [0] * (max_len - len(m)) for m in attention_mask]
                results = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in results]
            input_ids = results

        if return_tensors == "pt":
            if not input_ids or not isinstance(input_ids[0], list):
                # Single sequence: add a batch dimension.
                input_ids = [input_ids]
                attention_mask = [attention_mask]
            input_ids = torch.tensor(input_ids, dtype=torch.long)
            attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def encode(self, text, add_special_tokens=True, *args, **kwargs):
        """Minimal encode implementation."""
        if isinstance(text, str):
            return [i % (self.vocab_size - 4) + 4 for i in range(len(text.split()))]
        return [[i % (self.vocab_size - 4) + 4 for i in range(len(t.split()))] for t in text]

    def decode(self, token_ids, skip_special_tokens=True, *args, **kwargs):
        """Minimal decode implementation."""
        return " ".join("token" + str(i) for i in token_ids)

    def batch_decode(self, sequences, skip_special_tokens=True, *args, **kwargs):
        """Minimal batch decode implementation."""
        return [self.decode(seq, skip_special_tokens=skip_special_tokens) for seq in sequences]
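

# Hedged behavior sketch for the placeholder above (values follow directly
# from the code, not from any real vocabulary):
#
#     MinimalTokenizer()("hello world")["input_ids"]  ->  [4, 5]
#     MinimalTokenizer().decode([4, 5])               ->  "token4 token5"
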
def get_tokenizer(model_name="gpt2", use_fast=True):
    """Create a tokenizer for the specified model."""
    # Prefer a shared tokenizer from the service registry when one exists.
    try:
        from service_registry import registry, TOKENIZER

        if registry.has(TOKENIZER):
            logger.info("Retrieved tokenizer from registry")
            return registry.get(TOKENIZER)
    except ImportError:
        pass

    # Otherwise build a fresh wrapper.
    return TokenizerWrapper(model_name=model_name, use_fast=use_fast)
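

if __name__ == "__main__":
    # Hedged smoke test of the public surface; assumes no service_registry
    # override. The "gpt2" download may fail offline, in which case the wrapper
    # falls back to MinimalTokenizer and the calls below still run, with
    # placeholder output.
    logging.basicConfig(level=logging.INFO)
    tok = get_tokenizer("gpt2")
    batch = tok(["hello world", "hi"], return_tensors="pt", padding=True)
    print("input_ids shape:", tuple(batch["input_ids"].shape))
    print("pad_token_id:", tok.pad_token_id, "eos_token_id:", tok.eos_token_id)
    print("decoded:", tok.batch_decode(batch["input_ids"], skip_special_tokens=True))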