File size: 6,562 Bytes
d681ea8 bbb716d d681ea8 bbb716d d681ea8 bbb716d d681ea8 bbb716d 80f2516 bbb716d d681ea8 bbb716d d681ea8 bbb716d d681ea8 80f2516 d681ea8 bbb716d d681ea8 bbb716d 80f2516 bbb716d 80f2516 d681ea8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 | #!/usr/bin/env python3
"""
Command-line interface for training BBPE tokenizers.
Usage:
python train_tokenizer.py --data_dir ./data --vocab_size 30000 --model_name EthioBBPE
"""
import argparse
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from bbpe_trainer import EthioBBPETrainer, BBPEConfig
def _positive_int(value):
    """Argparse ``type=`` callable that accepts only integers greater than zero.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is <= 0;
            argparse reports this as a usage error and exits with status 2.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}"
        ) from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly like the previous
            zero-argument version; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description="Train a Byte-Level BPE (BBPE) tokenizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data",
        help="Directory containing training text files (.txt, .json, .jsonl)",
    )
    parser.add_argument(
        "--files",
        type=str,
        nargs="+",
        default=None,
        help="Specific files to train on (overrides data_dir)",
    )

    # Model arguments
    parser.add_argument(
        "--vocab_size",
        # Reject 0/negative sizes up front with a clean usage error instead
        # of letting them reach the trainer.
        type=_positive_int,
        default=30000,
        help="Target vocabulary size",
    )
    parser.add_argument(
        "--min_frequency",
        type=int,
        default=2,
        help="Minimum frequency for tokens to be included in vocabulary",
    )
    parser.add_argument(
        "--special_tokens",
        type=str,
        nargs="+",
        default=["<pad>", "<unk>", "<s>", "</s>", "<mask>"],
        help="Special tokens to add to the vocabulary",
    )

    # Training options
    parser.add_argument(
        "--lowercase",
        action="store_true",
        help="Convert text to lowercase before tokenization",
    )
    parser.add_argument(
        "--no_prefix_space",
        action="store_true",
        help="Disable adding prefix space (default: add prefix space)",
    )
    # --show_progress is effectively a no-op (its default is already True);
    # it is kept so existing command lines keep working. --no_progress is the
    # switch that actually turns the progress bar off.
    parser.add_argument(
        "--show_progress",
        action="store_true",
        default=True,
        help="Show training progress bar",
    )
    parser.add_argument(
        "--no_progress",
        action="store_false",
        dest="show_progress",
        help="Hide training progress bar",
    )

    # Output arguments
    parser.add_argument(
        "--model_save_dir",
        type=str,
        default="models",
        help="Directory to save the trained tokenizer",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="EthioBBPE",
        help="Name for the saved tokenizer model",
    )

    # Config file arguments
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
        help="Path to JSON config file (overrides other arguments)",
    )
    parser.add_argument(
        "--save_config",
        type=str,
        default=None,
        help="Path to save the configuration JSON file",
    )

    # Advanced production features. As with --show_progress, the positive
    # flags below default to True and only the --no_* flags change anything.
    parser.add_argument(
        "--use_checkpoint",
        action="store_true",
        default=True,
        help="Enable checkpointing during training",
    )
    parser.add_argument(
        "--no_checkpoint",
        action="store_false",
        dest="use_checkpoint",
        help="Disable checkpointing",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./models/checkpoints",
        help="Directory to save checkpoints",
    )
    parser.add_argument(
        "--save_compressed",
        action="store_true",
        default=True,
        help="Save tokenizer files in compressed format (.gz)",
    )
    parser.add_argument(
        "--no_compression",
        action="store_false",
        dest="save_compressed",
        help="Disable compression",
    )

    return parser.parse_args(argv)
def _run_smoke_test(trainer):
    """Print encode/tokenize/decode output for a few fixed sample sentences.

    This is only a visual sanity check of the freshly trained tokenizer; it
    asserts nothing.
    """
    print("\n" + "=" * 60)
    print("TESTING TOKENIZER")
    print("=" * 60)

    test_texts = [
        "Hello, world!",
        "This is a test of the EthioBBPE tokenizer.",
        "Special characters: @#$%^&*()",
        "Numbers: 12345 and words mixed together.",
    ]
    for text in test_texts:
        encoded = trainer.encode(text)
        tokens = trainer.tokenize(text)
        decoded = trainer.decode(encoded)
        print(f"\nInput: {text}")
        print(f"Tokens: {tokens}")
        # Show at most 20 IDs to keep the output readable.
        print(f"IDs: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f"Decoded: {decoded}")


def main():
    """Main entry point for CLI training.

    Builds a ``BBPEConfig``, either from ``--config_file`` or from the
    individual CLI options, and optionally writes it to ``--save_config``.
    It then trains an ``EthioBBPETrainer`` on ``--files`` (or, when those are
    not given, on whatever the trainer picks from the config), saves the
    tokenizer, and prints a short encode/decode demo plus a summary.

    Exits with status 1 if training raises ``FileNotFoundError``.
    """
    args = parse_args()

    # Load config from file if provided. In that case the model/output options
    # given on the command line are not used to build the config.
    if args.config_file:
        print(f"Loading configuration from {args.config_file}")
        config = BBPEConfig.load(args.config_file)
    else:
        config = BBPEConfig(
            vocab_size=args.vocab_size,
            min_frequency=args.min_frequency,
            special_tokens=args.special_tokens,
            lowercase=args.lowercase,
            show_progress=args.show_progress,
            data_dir=args.data_dir,
            model_save_dir=args.model_save_dir,
            model_name=args.model_name,
            use_checkpoint=args.use_checkpoint,
            checkpoint_dir=args.checkpoint_dir,
            save_compressed=args.save_compressed,
        )

    if args.no_prefix_space:
        # NOTE(review): --no_prefix_space is parsed but never forwarded to
        # BBPEConfig, so it has no effect. Say so explicitly instead of
        # silently ignoring it; wire it through once the config has a field.
        print(
            "Warning: --no_prefix_space is currently ignored "
            "(it is not passed to BBPEConfig).",
            file=sys.stderr,
        )

    # Report the settings the config actually holds. This matters when
    # --config_file is used, because the CLI values may differ from it.
    # getattr falls back to the CLI value if the config lacks the attribute.
    data_dir = getattr(config, "data_dir", args.data_dir)
    save_compressed = getattr(config, "save_compressed", args.save_compressed)
    use_checkpoint = getattr(config, "use_checkpoint", args.use_checkpoint)
    checkpoint_dir = getattr(config, "checkpoint_dir", args.checkpoint_dir)

    # Save config if requested
    if args.save_config:
        config.save(args.save_config)
        print(f"Configuration saved to {args.save_config}")

    trainer = EthioBBPETrainer(config)

    # None lets the trainer choose its input files from the config.
    files = None
    if args.files:
        print(f"Using specified files: {args.files}")
        files = args.files

    try:
        trainer.train(files=files)
    except FileNotFoundError as e:
        print(f"\nError: {e}", file=sys.stderr)
        print("\nTo fix this:", file=sys.stderr)
        print(f" 1. Add your training data to the '{data_dir}' directory", file=sys.stderr)
        print(" 2. Supported formats: .txt, .json, .jsonl", file=sys.stderr)
        print(" 3. Or specify files directly with --files flag", file=sys.stderr)
        sys.exit(1)

    save_path = trainer.save()

    _run_smoke_test(trainer)

    print("\n" + "=" * 60)
    print("Tokenizer training complete!")
    print(f"Model saved to: {save_path}")
    if save_compressed:
        print("Compressed files also saved (look for .gz files)")
    if use_checkpoint:
        print(f"Checkpoints saved to: {checkpoint_dir}")
    print("=" * 60)
# Run the command-line interface only when this file is executed as a
# script, never as a side effect of importing it.
if __name__ == "__main__":
    main()
|