| |
| """ |
| Command-line interface for training BBPE tokenizers. |
| |
| Usage: |
| python train_tokenizer.py --data_dir ./data --vocab_size 30000 --model_name EthioBBPE |
| """ |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from bbpe_trainer import EthioBBPETrainer, BBPEConfig |
|
|
|
|
def _positive_int(value):
    """Convert a command-line string to a strictly positive int.

    Intended as an argparse ``type=`` callable. Raising
    ``argparse.ArgumentTypeError`` makes argparse report a normal usage error
    (exit status 2) instead of handing a nonsensical size to the trainer.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, which matches the original
            zero-argument behaviour. Passing an explicit list makes the parser
            usable from tests or other scripts.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
        ``show_progress``, ``use_checkpoint`` and ``save_compressed`` default
        to True. They are switched off by ``--no_progress``, ``--no_checkpoint``
        and ``--no_compression`` respectively.

    Raises:
        SystemExit: Raised by argparse on invalid arguments (status 2) or after
            printing ``--help`` (status 0).
    """
    parser = argparse.ArgumentParser(
        description="Train a Byte-Level BPE (BBPE) tokenizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Input data -------------------------------------------------------
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data",
        help="Directory containing training text files (.txt, .json, .jsonl)",
    )

    parser.add_argument(
        "--files",
        type=str,
        nargs="+",
        default=None,
        help="Specific files to train on (overrides data_dir)",
    )

    # --- Tokenizer hyper-parameters --------------------------------------
    parser.add_argument(
        "--vocab_size",
        type=_positive_int,  # reject 0 / negative sizes at parse time
        default=30000,
        help="Target vocabulary size",
    )

    parser.add_argument(
        "--min_frequency",
        type=int,
        default=2,
        help="Minimum frequency for tokens to be included in vocabulary",
    )

    parser.add_argument(
        "--special_tokens",
        type=str,
        nargs="+",
        default=["<pad>", "<unk>", "<s>", "</s>", "<mask>"],
        help="Special tokens to add to the vocabulary",
    )

    # --- Normalisation / display -----------------------------------------
    parser.add_argument(
        "--lowercase",
        action="store_true",
        help="Convert text to lowercase before tokenization",
    )

    parser.add_argument(
        "--no_prefix_space",
        action="store_true",
        help="Disable adding prefix space (default: add prefix space)",
    )

    # store_true with default=True: passing --show_progress cannot change the
    # value; it is only an explicit positive spelling. --no_progress shares
    # the same dest and stores False.
    parser.add_argument(
        "--show_progress",
        action="store_true",
        default=True,
        help="Show training progress bar",
    )

    parser.add_argument(
        "--no_progress",
        action="store_false",
        dest="show_progress",
        help="Hide training progress bar",
    )

    # --- Output locations ------------------------------------------------
    parser.add_argument(
        "--model_save_dir",
        type=str,
        default="models",
        help="Directory to save the trained tokenizer",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        default="EthioBBPE",
        help="Name for the saved tokenizer model",
    )

    # --- Configuration files ---------------------------------------------
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
        help="Path to JSON config file (overrides other arguments)",
    )

    parser.add_argument(
        "--save_config",
        type=str,
        default=None,
        help="Path to save the configuration JSON file",
    )

    # --- Checkpointing / compression -------------------------------------
    # Same store_true/default=True pattern as --show_progress: only the
    # --no_* twin actually changes the resulting value.
    parser.add_argument(
        "--use_checkpoint",
        action="store_true",
        default=True,
        help="Enable checkpointing during training",
    )

    parser.add_argument(
        "--no_checkpoint",
        action="store_false",
        dest="use_checkpoint",
        help="Disable checkpointing",
    )

    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./models/checkpoints",
        help="Directory to save checkpoints",
    )

    parser.add_argument(
        "--save_compressed",
        action="store_true",
        default=True,
        help="Save tokenizer files in compressed format (.gz)",
    )

    parser.add_argument(
        "--no_compression",
        action="store_false",
        dest="save_compressed",
        help="Disable compression",
    )

    return parser.parse_args(argv)
|
|
|
|
# Sample sentences encoded/decoded after training as a quick sanity check.
_SMOKE_TEST_TEXTS = (
    "Hello, world!",
    "This is a test of the EthioBBPE tokenizer.",
    "Special characters: @#$%^&*()",
    "Numbers: 12345 and words mixed together.",
)

_BANNER = "=" * 60


def _setting(config, args, name):
    """Return ``config.<name>`` when the config exposes it, else ``args.<name>``.

    When ``--config_file`` is used, the CLI flags keep their defaults. The
    loaded config is therefore the better source for what training actually
    used.
    """
    return getattr(config, name, getattr(args, name))


def _build_config(args):
    """Return a BBPEConfig loaded from --config_file or built from CLI flags."""
    if args.config_file:
        print(f"Loading configuration from {args.config_file}")
        return BBPEConfig.load(args.config_file)

    if args.no_prefix_space:
        # NOTE(review): --no_prefix_space is parsed, but no prefix-space
        # setting is passed to BBPEConfig below, so the flag has no effect.
        # Warn rather than ignore it silently.
        # TODO: forward it once the BBPEConfig field name is confirmed.
        print(
            "Warning: --no_prefix_space is not currently forwarded to the "
            "tokenizer configuration and has no effect.",
            file=sys.stderr,
        )

    return BBPEConfig(
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
        special_tokens=args.special_tokens,
        lowercase=args.lowercase,
        show_progress=args.show_progress,
        data_dir=args.data_dir,
        model_save_dir=args.model_save_dir,
        model_name=args.model_name,
        use_checkpoint=args.use_checkpoint,
        checkpoint_dir=args.checkpoint_dir,
        save_compressed=args.save_compressed,
    )


def _print_missing_data_help(error, data_dir):
    """Explain on stderr how to supply training data after a FileNotFoundError."""
    print(f"\nError: {error}", file=sys.stderr)
    print("\nTo fix this:", file=sys.stderr)
    print(f" 1. Add your training data to the '{data_dir}' directory", file=sys.stderr)
    print(" 2. Supported formats: .txt, .json, .jsonl", file=sys.stderr)
    print(" 3. Or specify files directly with --files flag", file=sys.stderr)


def _run_smoke_test(trainer):
    """Print tokens, ids (first 20) and the round-trip decode of sample texts."""
    print("\n" + _BANNER)
    print("TESTING TOKENIZER")
    print(_BANNER)

    for text in _SMOKE_TEST_TEXTS:
        encoded = trainer.encode(text)
        tokens = trainer.tokenize(text)
        decoded = trainer.decode(encoded)

        print(f"\nInput: {text}")
        print(f"Tokens: {tokens}")
        print(f"IDs: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f"Decoded: {decoded}")


def _print_summary(config, args, save_path):
    """Print where the model (and optional extras) were written."""
    print("\n" + _BANNER)
    print("Tokenizer training complete!")
    print(f"Model saved to: {save_path}")
    if _setting(config, args, "save_compressed"):
        print("Compressed files also saved (look for .gz files)")
    if _setting(config, args, "use_checkpoint"):
        print(f"Checkpoints saved to: {_setting(config, args, 'checkpoint_dir')}")
    print(_BANNER)


def main():
    """Main entry point for CLI training.

    The steps are:
      1. Build a BBPEConfig (from --config_file or the individual flags) and
         optionally save it to --save_config.
      2. Train an EthioBBPETrainer, on --files if given, otherwise with
         ``files=None``.
      3. Save the trained tokenizer.
      4. Run a short encode/decode smoke test and print a summary.

    Exits with status 1 if training raises FileNotFoundError.
    """
    args = parse_args()
    config = _build_config(args)

    if args.save_config:
        config.save(args.save_config)
        print(f"Configuration saved to {args.save_config}")

    trainer = EthioBBPETrainer(config)

    # nargs="+" guarantees args.files is either None or a non-empty list.
    files = args.files or None
    if files:
        print(f"Using specified files: {files}")

    try:
        trainer.train(files=files)
    except FileNotFoundError as e:
        _print_missing_data_help(e, _setting(config, args, "data_dir"))
        sys.exit(1)

    save_path = trainer.save()

    _run_smoke_test(trainer)
    _print_summary(config, args, save_path)
|
|
|
|
if __name__ == "__main__":
    # Train only when executed as a script; importing this module has no
    # side effects beyond definitions (and the sys.path tweak above).
    main()
|
|