"""Tokenizer training utilities.""" import os import json import tempfile from pathlib import Path from typing import Optional, Dict, Any from taoTrain.data.chunk_manager import ChunkManager class TokenizerTrainer: """Train SentencePiece tokenizers from JSONL data.""" @staticmethod def train_from_config(config: "TokenizerConfig") -> dict: # type: ignore """ Train a tokenizer from a TokenizerConfig object. Args: config: TokenizerConfig instance Returns: Dict with paths to generated tokenizer files """ # Build special tokens string for SentencePiece if provided user_defined_symbols = None if config.special_tokens: # Sort by ID and format as comma-separated tokens sorted_tokens = sorted(config.special_tokens.items(), key=lambda x: x[1]) user_defined_symbols = ','.join([token for token, _ in sorted_tokens]) return TokenizerTrainer.train_sentencepiece( jsonl_path=config.jsonl_path, output_dir=config.output_dir, vocab_size=config.vocab_size, model_type=config.model_type, character_coverage=config.character_coverage, unk_id=config.unk_id, bos_id=config.bos_id, eos_id=config.eos_id, pad_id=config.pad_id, tokenizer_prefix=config.tokenizer_prefix, text_field=config.text_field, max_samples=config.max_samples, user_defined_symbols=user_defined_symbols, ) @staticmethod def train_sentencepiece( jsonl_path: str, output_dir: str = "tokenizers", vocab_size: int = 50000, model_type: str = "unigram", character_coverage: float = 0.9995, unk_id: int = 0, bos_id: int = 1, eos_id: int = 2, pad_id: int = 3, tokenizer_prefix: Optional[str] = None, text_field: str = "text", max_samples: Optional[int] = None, user_defined_symbols: Optional[str] = None, ) -> dict: """ Train a SentencePiece tokenizer from JSONL data. Args: jsonl_path: Path to JSONL file containing text data output_dir: Directory to save tokenizer files vocab_size: Vocabulary size for the tokenizer model_type: Model type (unigram, bpe, char, word) character_coverage: Character coverage for SentencePiece unk_id: Unknown token ID bos_id: Beginning of sentence token ID eos_id: End of sentence token ID pad_id: Padding token ID tokenizer_prefix: Prefix for tokenizer model files (default: model_type) text_field: Field name in JSONL for text data (default: "text") max_samples: Limit training to first N samples (optional) user_defined_symbols: Custom special tokens as comma-separated string (optional) Returns: Dict with paths to generated tokenizer files Raises: ImportError: If SentencePiece is not installed FileNotFoundError: If JSONL file doesn't exist ValueError: If JSONL file is invalid or empty """ try: import sentencepiece as spm except ImportError: raise ImportError( "SentencePiece not installed. Install with: pip install sentencepiece" ) # Validate paths jsonl_path = Path(jsonl_path) if not jsonl_path.exists(): raise FileNotFoundError(f"JSONL file not found: {jsonl_path}") output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Set tokenizer prefix if tokenizer_prefix is None: tokenizer_prefix = model_type # Extract text data from JSONL using ChunkManager for large files if max_samples: print(f"📖 Reading JSONL file (up to {max_samples:,} samples): {jsonl_path}") else: print(f"📖 Reading JSONL file: {jsonl_path}") # Use ChunkManager for efficient streaming (with metadata caching enabled) chunk_manager = ChunkManager( jsonl_path, chunk_size_gb=5.0, enable_metadata_cache=True, chunk_cache_dir=".cache/chunks" ) # Write text to temporary file for SentencePiece (use streaming) with tempfile.NamedTemporaryFile( mode='w', suffix='.txt', delete=False, encoding='utf-8' ) as tmp: text_count = 0 # Process chunks one at a time for chunk_num in range(chunk_manager.num_chunks): print(f" - Processing chunk {chunk_num + 1}/{chunk_manager.num_chunks}...") chunk_examples = chunk_manager.read_chunk(chunk_num) for obj in chunk_examples: # Check if we've reached max_samples limit if max_samples and text_count >= max_samples: break # Extract text from specified field or try common field names text = None if text_field in obj and isinstance(obj[text_field], str): text = obj[text_field] else: # Fallback to common field names for field in ['text', 'content', 'data', 'body']: if field in obj and isinstance(obj[field], str): text = obj[field] break if text: # Clean text: remove newlines, extra spaces clean_text = ' '.join(text.split()) tmp.write(clean_text + '\n') text_count += 1 # Break outer loop if max_samples reached if max_samples and text_count >= max_samples: print(f"Reached max_samples limit of {max_samples:,}. Stopping data processing.") break tmp_path = tmp.name if text_count == 0: os.remove(tmp_path) raise ValueError("No valid text data found in JSONL file") sample_info = f"{text_count:,} samples" if max_samples else f"{text_count:,} lines" print(f"✓ Processed {sample_info} with text data from {chunk_manager.num_chunks} chunks") try: # Train SentencePiece model print(f"🔧 Training SentencePiece {model_type} tokenizer...") print(f" - Vocabulary size: {vocab_size}") print(f" - Character coverage: {character_coverage}") if user_defined_symbols: print(f" - Special tokens: {user_defined_symbols}") model_path = output_dir / tokenizer_prefix # Prepare training arguments train_kwargs = { 'input': tmp_path, 'model_prefix': str(model_path), 'vocab_size': vocab_size, 'model_type': model_type, 'character_coverage': character_coverage, 'unk_id': unk_id, 'bos_id': bos_id, 'eos_id': eos_id, 'pad_id': pad_id, # Additional options 'normalization_rule_name': 'identity', 'split_digits': True } # Add user-defined symbols if provided if user_defined_symbols: train_kwargs['user_defined_symbols'] = user_defined_symbols spm.SentencePieceTrainer.train(**train_kwargs) model_file = model_path.with_suffix('.model') vocab_file = model_path.with_suffix('.vocab') if model_file.exists() and vocab_file.exists(): print(f"✅ Tokenizer trained successfully!") print(f" - Model: {model_file}") print(f" - Vocab: {vocab_file}") return { "model_file": str(model_file), "vocab_file": str(vocab_file), "output_dir": str(output_dir), "vocab_size": vocab_size, "model_type": model_type, } else: raise RuntimeError("SentencePiece training didn't produce output files") finally: # Clean up temporary file if os.path.exists(tmp_path): os.remove(tmp_path) @staticmethod def validate_tokenizer(model_path: str) -> bool: """ Validate that a SentencePiece tokenizer file is valid. Args: model_path: Path to .model file Returns: True if valid, False otherwise """ try: import sentencepiece as spm sp = spm.SentencePieceProcessor() sp.Load(model_path) # Try a simple encode/decode test_text = "Hello world" tokens = sp.encode(test_text, out_type=int) decoded = sp.decode(tokens) return len(tokens) > 0 except Exception: return False