| | """Advanced Tokenization with Multi-Tokenizer Support and Optimization"""
|
| |
|
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from tokenizers import Tokenizer as HFTokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from transformers import AutoTokenizer, PreTrainedTokenizer
|
| |
|
# Module-level logger; name follows the standard `__name__` convention.
logger = logging.getLogger(__name__)
|
| |
|
| |
|
@dataclass
class TokenizerConfig:
    """Configuration for the advanced tokenizer.

    Covers pretrained loading, optional custom WordLevel training,
    padding/truncation behavior, and multi-modal toggles.
    """

    # Hugging Face model ID (or local path) loaded when `use_custom_tokenizer` is False.
    tokenizer_name: str = "meta-llama/Llama-2-7b-hf"
    # Train a WordLevel tokenizer from the dataset instead of loading a pretrained one.
    use_custom_tokenizer: bool = False
    # Target vocabulary size for custom training.
    custom_vocab_size: int = 32000
    # Minimum corpus frequency for a token to enter the custom vocabulary.
    min_frequency: int = 2
    # Special tokens keyed by role; mutable default, so `field(default_factory=...)`
    # is required. The "/thought_token" key is the closing counterpart of
    # "thought_token".
    special_tokens: Dict[str, str] = field(default_factory=lambda: {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "pad_token": "<pad>",
        "unk_token": "<unk>",
        "mask_token": "<mask>",
        "system_token": "<system>",
        "user_token": "<user>",
        "assistant_token": "<assistant>",
        "thought_token": "<thought>",
        "/thought_token": "</thought>",
    })

    # Runtime behavior passed through to the underlying tokenizer.
    use_fast: bool = True
    padding_side: str = "right"
    truncation_side: str = "right"
    model_max_length: int = 32768

    # Multi-modal toggles — NOTE(review): not consumed anywhere in this module;
    # confirm downstream usage before relying on them.
    enable_image_tokenization: bool = False
    enable_audio_tokenization: bool = False
|
| |
|
| |
|
class AdvancedTokenizer:
    """Advanced tokenizer with custom training, optimization, and multi-modal support."""

    def __init__(self, config: TokenizerConfig):
        self.config = config
        self.tokenizer: Optional[PreTrainedTokenizer] = None
        # Flat list of the special-token strings, consumed by custom training.
        self._special_tokens = list(config.special_tokens.values())

    def load_or_train(self, dataset: Optional[Any] = None) -> PreTrainedTokenizer:
        """Load existing tokenizer or train a new one from ``dataset``.

        Args:
            dataset: Iterable of samples; required only when
                ``config.use_custom_tokenizer`` is True.

        Returns:
            The ready tokenizer (also stored on ``self.tokenizer``).

        Raises:
            ValueError: If custom training is requested without a dataset.
        """
        if not self.config.use_custom_tokenizer:
            logger.info(f"Loading pretrained tokenizer: {self.config.tokenizer_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.tokenizer_name,
                use_fast=self.config.use_fast,
                padding_side=self.config.padding_side,
                truncation_side=self.config.truncation_side,
                model_max_length=self.config.model_max_length,
            )
        else:
            if dataset is None:
                raise ValueError("Dataset required for custom tokenizer training")
            logger.info("Training custom tokenizer from dataset")
            self.tokenizer = self._train_tokenizer(dataset)

        self._setup_special_tokens()
        return self.tokenizer

    def _train_tokenizer(self, dataset: Any) -> PreTrainedTokenizer:
        """Train a WordLevel tokenizer from scratch on ``dataset``."""
        import tempfile
        temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False)
        temp_file.close()

        # BUGFIX: try/finally so the delete=False temp file is removed even
        # when extraction or training raises.
        try:
            logger.info("Preparing training data...")
            with open(temp_file.name, 'w', encoding='utf-8') as f:
                for sample in dataset:
                    text = self._extract_text_for_tokenizer(sample)
                    if text:
                        f.write(text + '\n')

            tokenizer = HFTokenizer(WordLevel(unk_token="<unk>"))
            tokenizer.pre_tokenizer = Whitespace()

            trainer = WordLevelTrainer(
                vocab_size=self.config.custom_vocab_size,
                min_frequency=self.config.min_frequency,
                special_tokens=self._special_tokens,
            )

            logger.info("Training tokenizer...")
            tokenizer.train([temp_file.name], trainer=trainer)

            # Wrap the raw tokenizers object so it exposes the transformers API.
            from transformers import PreTrainedTokenizerFast
            fast_tokenizer = PreTrainedTokenizerFast(
                tokenizer_object=tokenizer,
                bos_token=self.config.special_tokens["bos_token"],
                eos_token=self.config.special_tokens["eos_token"],
                pad_token=self.config.special_tokens["pad_token"],
                unk_token=self.config.special_tokens["unk_token"],
                mask_token=self.config.special_tokens["mask_token"],
                padding_side=self.config.padding_side,
                truncation_side=self.config.truncation_side,
                model_max_length=self.config.model_max_length,
            )
        finally:
            Path(temp_file.name).unlink(missing_ok=True)

        logger.info(f"Trained tokenizer with vocab size: {fast_tokenizer.vocab_size}")
        return fast_tokenizer

    def _extract_text_for_tokenizer(self, sample: Dict[str, Any]) -> str:
        """Extract plain text from one dataset sample for tokenizer training.

        Accepts either a ``conversations`` list (or its JSON-string form),
        or a flat ``text`` / ``content`` field; returns "" when none match.
        """
        if "conversations" in sample:
            conv = sample["conversations"]
            if isinstance(conv, str):
                try:
                    conv = json.loads(conv)
                # BUGFIX: narrowed from bare `except:`; a non-JSON string is
                # simply treated as raw text.
                except json.JSONDecodeError:
                    return conv
            texts = []
            for msg in conv:
                if isinstance(msg, dict):
                    role = msg.get("role", "")
                    content = msg.get("content", "")
                    if content:
                        # Prefix known roles with their configured marker token.
                        if role == "user":
                            texts.append(f"{self.config.special_tokens['user_token']} {content}")
                        elif role == "assistant":
                            texts.append(f"{self.config.special_tokens['assistant_token']} {content}")
                        elif role == "system":
                            texts.append(f"{self.config.special_tokens['system_token']} {content}")
                        else:
                            texts.append(content)
            return "\n".join(texts)
        elif "text" in sample:
            return sample["text"]
        elif "content" in sample:
            return sample["content"]
        return ""

    def _setup_special_tokens(self):
        """Register missing special tokens and install the chat template.

        Raises:
            ValueError: If called before a tokenizer is loaded/trained.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        # Keys `add_special_tokens` accepts directly; anything else must be
        # routed through `additional_special_tokens`.
        standard_keys = {
            "bos_token", "eos_token", "pad_token", "unk_token",
            "sep_token", "cls_token", "mask_token",
        }
        vocab = self.tokenizer.get_vocab()
        special_tokens_dict: Dict[str, Any] = {}
        additional: List[str] = []
        for key, token in self.config.special_tokens.items():
            if token in vocab:
                continue
            if key in standard_keys:
                special_tokens_dict[key] = token
            else:
                # BUGFIX: custom role keys (system_token, thought_token, ...)
                # previously went straight into add_special_tokens, which
                # rejects non-canonical keys.
                additional.append(token)
        if additional:
            special_tokens_dict["additional_special_tokens"] = additional

        if special_tokens_dict:
            self.tokenizer.add_special_tokens(special_tokens_dict)

        if self.config.use_fast:
            self.tokenizer.chat_template = self._create_chat_template()

    def _create_chat_template(self) -> str:
        """Create Jinja2 chat template.

        NOTE(review): this template emits literal ``{{ system }}``-style
        markers rather than the configured special tokens (e.g. ``<system>``)
        — confirm that is intended before changing it.
        """
        template = """{% for message in messages %}
{% if message['role'] == 'system' %}{{ '{{' }} system {{ '}}' }}{{ message['content'] }}{{ '{{' }} /system {{ '}}' }}
{% elif message['role'] == 'user' %}{{ '{{' }} user {{ '}}' }}{{ message['content'] }}{{ '{{' }} /user {{ '}}' }}
{% elif message['role'] == 'assistant' %}{{ '{{' }} assistant {{ '}}' }}{{ message['content'] }}{{ '{{' }} /assistant {{ '}}' }}
{% endif %}
{% endfor %}"""
        return template

    def tokenize(
        self,
        text: Union[str, List[str]],
        **kwargs
    ) -> Dict[str, Any]:
        """Tokenize ``text`` with padding/truncation defaults from config.

        Keyword arguments override the defaults (truncation, max_length,
        padding="max_length", return_tensors="pt").

        Raises:
            ValueError: If called before a tokenizer is loaded/trained.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        tokenize_kwargs = {
            "truncation": True,
            "max_length": self.config.model_max_length,
            "padding": "max_length",
            "return_tensors": "pt",
        }
        tokenize_kwargs.update(kwargs)

        return self.tokenizer(text, **tokenize_kwargs)

    def decode(self, token_ids: Union[List[int], Any], **kwargs) -> str:
        """Decode token IDs to text.

        Raises:
            ValueError: If called before a tokenizer is loaded/trained.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")
        return self.tokenizer.decode(token_ids, **kwargs)

    def save(self, path: str):
        """Save the tokenizer to ``path`` via ``save_pretrained``.

        Raises:
            ValueError: If called before a tokenizer is loaded/trained.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")
        self.tokenizer.save_pretrained(path)
        logger.info(f"Tokenizer saved to {path}")

    @property
    def vocab_size(self) -> int:
        """Base vocabulary size, or 0 when no tokenizer is loaded yet."""
        if self.tokenizer is None:
            return 0
        return self.tokenizer.vocab_size
|
| |
|
| |
|
class TokenizerManager:
    """Registry of named tokenizers for different model sizes."""

    def __init__(self):
        # Maps registry name -> AdvancedTokenizer wrapper.
        self.tokenizers: Dict[str, AdvancedTokenizer] = {}

    def register_tokenizer(self, name: str, tokenizer: AdvancedTokenizer):
        """Register ``tokenizer`` under ``name``, replacing any prior entry."""
        self.tokenizers[name] = tokenizer

    def get_tokenizer(self, name: str) -> PreTrainedTokenizer:
        """Return the underlying tokenizer registered as ``name``.

        Raises:
            KeyError: If nothing is registered under ``name``.
        """
        if name in self.tokenizers:
            return self.tokenizers[name].tokenizer
        raise KeyError(f"Tokenizer '{name}' not found")

    def load_all(self, dataset: Optional[Any] = None):
        """Load or train every registered tokenizer (same dataset for all)."""
        for label in self.tokenizers:
            logger.info(f"Loading tokenizer: {label}")
            self.tokenizers[label].load_or_train(dataset)

    def save_all(self, output_dir: str):
        """Persist each tokenizer under ``<output_dir>/<name>/tokenizer``."""
        root = Path(output_dir)
        for label, entry in self.tokenizers.items():
            entry.save(str(root / label / "tokenizer"))
|
| |
|
| |
|
def create_tokenizer_for_model_size(
    model_size: str,
    config: TokenizerConfig,
) -> AdvancedTokenizer:
    """Create a tokenizer configured for a specific model size.

    Mutates ``config`` in place (max length + tokenizer name) and wraps it.

    Raises:
        ValueError: If ``model_size`` is not one of "7b", "32b", "70b".
    """
    # (model_max_length, tokenizer_name) presets per supported size.
    presets = {
        "7b": (8192, "meta-llama/Llama-2-7b-hf"),
        "32b": (8192, "Qwen/Qwen1.5-32B"),
        "70b": (32768, "meta-llama/Llama-2-70b-hf"),
    }
    if model_size not in presets:
        raise ValueError(f"Unknown model size: {model_size}")

    config.model_max_length, config.tokenizer_name = presets[model_size]
    return AdvancedTokenizer(config)
|
| |
|