# Uploaded by StarMist0012: "Add files using upload-large-folder tool" (commit 3270dae, verified).
"""Tokenizer training utilities."""
import json
import os
import tempfile
from collections.abc import Mapping
from pathlib import Path
from typing import Optional, Dict, Any

from taoTrain.data.chunk_manager import ChunkManager
class TokenizerTrainer:
    """Train SentencePiece tokenizers from JSONL data."""

    # Field names tried, in order, when a record has no string under ``text_field``.
    _FALLBACK_TEXT_FIELDS = ("text", "content", "data", "body")

    @staticmethod
    def train_from_config(config: "TokenizerConfig") -> dict:  # type: ignore
        """
        Train a tokenizer from a TokenizerConfig object.

        All training options are read from ``config`` and forwarded to
        :meth:`train_sentencepiece`. If ``config.special_tokens`` is non-empty
        it is treated as a ``{token: id}`` mapping: tokens are ordered by id and
        passed on as a comma-separated ``user_defined_symbols`` string. The ids
        are used only for that ordering; they are not forwarded.

        Args:
            config: TokenizerConfig instance

        Returns:
            Dict with paths to generated tokenizer files
        """
        user_defined_symbols = None
        if config.special_tokens:
            ordered = sorted(config.special_tokens.items(), key=lambda item: item[1])
            # NOTE(review): the comma-separated encoding would split a token that
            # itself contains ","; tokens are assumed to be comma-free.
            user_defined_symbols = ",".join(token for token, _ in ordered)
        return TokenizerTrainer.train_sentencepiece(
            jsonl_path=config.jsonl_path,
            output_dir=config.output_dir,
            vocab_size=config.vocab_size,
            model_type=config.model_type,
            character_coverage=config.character_coverage,
            unk_id=config.unk_id,
            bos_id=config.bos_id,
            eos_id=config.eos_id,
            pad_id=config.pad_id,
            tokenizer_prefix=config.tokenizer_prefix,
            text_field=config.text_field,
            max_samples=config.max_samples,
            user_defined_symbols=user_defined_symbols,
        )

    @staticmethod
    def _extract_text(record, text_field: str) -> Optional[str]:
        """Return the record's string under ``text_field``, else the first string fallback field, else None."""
        # Records that are not mappings (e.g. a bare JSON string or list) carry no
        # named fields; skip them instead of letting ``record[field]`` raise.
        if not isinstance(record, Mapping):
            return None
        value = record.get(text_field)
        if isinstance(value, str):
            return value
        for field in TokenizerTrainer._FALLBACK_TEXT_FIELDS:
            value = record.get(field)
            if isinstance(value, str):
                return value
        return None

    @staticmethod
    def _write_training_text(out, chunk_manager, text_field: str, max_samples: Optional[int]):
        """Stream one whitespace-normalised line per usable record into ``out``; return (lines written, chunks read)."""
        text_count = 0
        chunks_read = 0
        for chunk_num in range(chunk_manager.num_chunks):
            print(f" - Processing chunk {chunk_num + 1}/{chunk_manager.num_chunks}...")
            chunks_read += 1
            for record in chunk_manager.read_chunk(chunk_num):
                if max_samples and text_count >= max_samples:
                    break
                text = TokenizerTrainer._extract_text(record, text_field)
                if text is None:
                    continue
                # Collapse all whitespace (including newlines) so each record is
                # exactly one line; records that are blank after this are skipped
                # so they neither produce empty lines nor count as samples.
                clean_text = " ".join(text.split())
                if not clean_text:
                    continue
                out.write(clean_text + "\n")
                text_count += 1
            if max_samples and text_count >= max_samples:
                print(f"Reached max_samples limit of {max_samples:,}. Stopping data processing.")
                break
        return text_count, chunks_read

    @staticmethod
    def train_sentencepiece(
        jsonl_path: str,
        output_dir: str = "tokenizers",
        vocab_size: int = 50000,
        model_type: str = "unigram",
        character_coverage: float = 0.9995,
        unk_id: int = 0,
        bos_id: int = 1,
        eos_id: int = 2,
        pad_id: int = 3,
        tokenizer_prefix: Optional[str] = None,
        text_field: str = "text",
        max_samples: Optional[int] = None,
        user_defined_symbols: Optional[str] = None,
    ) -> dict:
        """
        Train a SentencePiece tokenizer from JSONL data.

        Text is streamed chunk by chunk (via ``ChunkManager``) into a temporary
        UTF-8 file with one normalised record per line. That file is fed to
        ``SentencePieceTrainer.train`` and always removed afterwards.
        ``output_dir`` is created if missing. ChunkManager's cache directory is
        the relative path ``.cache/chunks``, i.e. under the current working
        directory.

        Args:
            jsonl_path: Path to JSONL file containing text data
            output_dir: Directory to save tokenizer files
            vocab_size: Vocabulary size for the tokenizer
            model_type: Model type (unigram, bpe, char, word)
            character_coverage: Character coverage for SentencePiece
            unk_id: Unknown token ID
            bos_id: Beginning of sentence token ID
            eos_id: End of sentence token ID
            pad_id: Padding token ID
            tokenizer_prefix: Prefix for tokenizer model files (default: model_type)
            text_field: Field name in JSONL for text data (default: "text");
                if absent or not a string, "text", "content", "data", "body"
                are tried in that order
            max_samples: Limit training to first N usable samples (optional;
                None or 0 means no limit)
            user_defined_symbols: Custom special tokens as comma-separated string (optional)

        Returns:
            Dict with keys "model_file", "vocab_file", "output_dir",
            "vocab_size" and "model_type"

        Raises:
            ImportError: If SentencePiece is not installed
            FileNotFoundError: If JSONL file doesn't exist
            ValueError: If no non-blank text could be extracted from the JSONL file
            RuntimeError: If training finished without producing the output files
        """
        try:
            import sentencepiece as spm
        except ImportError as err:
            raise ImportError(
                "SentencePiece not installed. Install with: pip install sentencepiece"
            ) from err

        # Validate paths
        jsonl_path = Path(jsonl_path)
        if not jsonl_path.exists():
            raise FileNotFoundError(f"JSONL file not found: {jsonl_path}")
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if tokenizer_prefix is None:
            tokenizer_prefix = model_type

        if max_samples:
            print(f"📖 Reading JSONL file (up to {max_samples:,} samples): {jsonl_path}")
        else:
            print(f"📖 Reading JSONL file: {jsonl_path}")

        # ChunkManager streams large files chunk by chunk (metadata caching on).
        chunk_manager = ChunkManager(
            jsonl_path,
            chunk_size_gb=5.0,
            enable_metadata_cache=True,
            chunk_cache_dir=".cache/chunks",
        )

        # delete=False: SentencePiece re-opens the file by name after we close it.
        # The name is captured before any writing so the finally below removes the
        # file on every path, including failures while reading chunks.
        tmp_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".txt",
            delete=False,
            encoding="utf-8",
        )
        tmp_path = tmp_file.name
        try:
            with tmp_file:
                text_count, chunks_read = TokenizerTrainer._write_training_text(
                    tmp_file, chunk_manager, text_field, max_samples
                )
            if text_count == 0:
                raise ValueError("No valid text data found in JSONL file")

            sample_info = f"{text_count:,} samples" if max_samples else f"{text_count:,} lines"
            print(f"✓ Processed {sample_info} with text data from {chunks_read} chunks")

            print(f"🔧 Training SentencePiece {model_type} tokenizer...")
            print(f" - Vocabulary size: {vocab_size}")
            print(f" - Character coverage: {character_coverage}")
            if user_defined_symbols:
                print(f" - Special tokens: {user_defined_symbols}")

            model_path = output_dir / tokenizer_prefix
            train_kwargs = {
                "input": tmp_path,
                "model_prefix": str(model_path),
                "vocab_size": vocab_size,
                "model_type": model_type,
                "character_coverage": character_coverage,
                "unk_id": unk_id,
                "bos_id": bos_id,
                "eos_id": eos_id,
                "pad_id": pad_id,
                # Additional options
                "normalization_rule_name": "identity",
                "split_digits": True,
            }
            if user_defined_symbols:
                train_kwargs["user_defined_symbols"] = user_defined_symbols

            spm.SentencePieceTrainer.train(**train_kwargs)

            # SentencePiece appends ".model"/".vocab" to model_prefix. Build the
            # names the same way: Path.with_suffix would *replace* any dotted
            # part of the prefix (e.g. "bpe.v2" -> "bpe.model").
            model_file = output_dir / f"{tokenizer_prefix}.model"
            vocab_file = output_dir / f"{tokenizer_prefix}.vocab"
            if not (model_file.exists() and vocab_file.exists()):
                raise RuntimeError("SentencePiece training didn't produce output files")

            print("✅ Tokenizer trained successfully!")
            print(f" - Model: {model_file}")
            print(f" - Vocab: {vocab_file}")
            return {
                "model_file": str(model_file),
                "vocab_file": str(vocab_file),
                "output_dir": str(output_dir),
                "vocab_size": vocab_size,
                "model_type": model_type,
            }
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    @staticmethod
    def validate_tokenizer(model_path: str) -> bool:
        """
        Validate that a SentencePiece tokenizer file is valid.

        Loads the model and round-trips a short sample through encode/decode.

        Args:
            model_path: Path to .model file (str or os.PathLike)

        Returns:
            True if the model loads, decodes without error, and encodes the
            sample to at least one token; False otherwise (including when
            SentencePiece is not installed)
        """
        try:
            import sentencepiece as spm

            sp = spm.SentencePieceProcessor()
            # str() so pathlib.Path arguments work with str-only bindings.
            sp.Load(str(model_path))
            tokens = sp.encode("Hello world", out_type=int)
            sp.decode(tokens)  # must not raise; the decoded text is not compared
            return len(tokens) > 0
        except Exception:
            # Deliberate best-effort check: any failure means "not valid".
            return False