"""Main CLI entry point.""" import sys from pathlib import Path from typing import Optional import click import torch from taoTrain.config import load_config, load_tokenizer_config, TrainingModeEnum, PretrainConfig, SFTConfig, RLConfig from taoTrain.utils import set_seed, get_device from taoTrain.core import BaseModel, create_model, create_datasets from taoTrain.data import get_dataloader from taoTrain.training import PretrainTrainer, SFTTrainer, RLTrainer from taoTrain.benchmarks import BenchmarkRunner from taoTrain.tokenizers import TokenizerTrainer @click.group() def main(): """TaoTrain: A clean, modular PyTorch LLM training framework.""" pass @main.command() @click.option( "--config", type=click.Path(exists=True), required=True, help="Path to training config file (YAML or JSON)", ) def pretrain(config: str): """Pretrain a language model.""" _train_command(config, TrainingModeEnum.PRETRAIN) @main.command() @click.option( "--config", type=click.Path(exists=True), required=True, help="Path to training config file (YAML or JSON)", ) def sft(config: str): """Supervised fine-tune a language model.""" _train_command(config, TrainingModeEnum.SFT) @main.command() @click.option( "--config", type=click.Path(exists=True), required=True, help="Path to training config file (YAML or JSON)", ) def rl(config: str): """Train with reinforcement learning.""" _train_command(config, TrainingModeEnum.RL) @main.command() @click.option( "--config", type=click.Path(exists=True), required=True, help="Path to tokenizer config file (YAML or JSON)", ) def train_tokenizer(config: str): """Train a SentencePiece tokenizer from a YAML/JSON config file.""" try: click.echo("šŸš€ TaoTrain Tokenizer Trainer") click.echo(f"{'=' * 50}") # Load tokenizer config click.echo(f"Loading config from {config}...") tokenizer_config = load_tokenizer_config(config) # Train tokenizer from config result = TokenizerTrainer.train_from_config(tokenizer_config) # Display results click.echo(f"\n{'=' * 50}") click.echo("āœ… Tokenizer Training Complete!") click.echo(f"\nšŸ“Š Configuration:") click.echo(f" - Input file: {tokenizer_config.jsonl_path}") click.echo(f" - Samples: {tokenizer_config.max_samples or 'all'}") click.echo(f" - Output dir: {result['output_dir']}") click.echo(f" - Vocab size: {result['vocab_size']}") click.echo(f" - Model type: {result['model_type']}") if tokenizer_config.special_tokens: click.echo(f" - Special tokens: {tokenizer_config.special_tokens}") click.echo(f"\nšŸ“ Generated Files:") click.echo(f" - Model: {result['model_file']}") click.echo(f" - Vocab: {result['vocab_file']}") click.echo(f"\nšŸ“ Next Steps:") click.echo(f" 1. Use this tokenizer in your pretraining config:") click.echo(f" dataset:") click.echo(f" local: true") click.echo(f" jsonl_path: {tokenizer_config.jsonl_path}") click.echo(f" tokenizer_path: {result['model_file']}") click.echo(f"") click.echo(f" 2. Run pretraining with:") click.echo(f" train pretrain --config your_config.yaml") except ImportError as e: click.echo(f"āŒ Error: {e}", err=True) sys.exit(1) except FileNotFoundError as e: click.echo(f"āŒ File Error: {e}", err=True) sys.exit(1) except ValueError as e: click.echo(f"āŒ Validation Error: {e}", err=True) sys.exit(1) except Exception as e: click.echo(f"āŒ Unexpected Error: {e}", err=True) click.echo(f" Please report this issue.", err=True) sys.exit(1) @click.command() @click.option( "--jsonl-path", type=click.Path(exists=True), required=True, help="Path to JSONL file containing training data", ) @click.option( "--output-dir", type=click.Path(), default="tokenizers", help="Directory to save tokenizer files", ) @click.option( "--vocab-size", type=int, default=50000, help="Vocabulary size for the tokenizer", ) @click.option( "--model-type", type=click.Choice(["unigram", "bpe", "char", "word"]), default="unigram", help="SentencePiece model type", ) @click.option( "--character-coverage", type=float, default=0.9995, help="Character coverage for SentencePiece", ) @click.option( "--tokenizer-prefix", type=str, default=None, help="Prefix for tokenizer output files (default: model_type)", ) def train_tokenizer_command( jsonl_path: str, output_dir: str, vocab_size: int, model_type: str, character_coverage: float, tokenizer_prefix: Optional[str], ): """Train a SentencePiece tokenizer from JSONL data.""" try: click.echo("šŸš€ TaoTrain Tokenizer Trainer") click.echo(f"{'=' * 50}") # Train tokenizer result = TokenizerTrainer.train_sentencepiece( jsonl_path=jsonl_path, output_dir=output_dir, vocab_size=vocab_size, model_type=model_type, character_coverage=character_coverage, tokenizer_prefix=tokenizer_prefix, ) # Display results click.echo(f"\n{'=' * 50}") click.echo("āœ… Tokenizer Training Complete!") click.echo(f"\nšŸ“Š Configuration:") click.echo(f" - Input file: {jsonl_path}") click.echo(f" - Output dir: {result['output_dir']}") click.echo(f" - Vocab size: {result['vocab_size']}") click.echo(f" - Model type: {result['model_type']}") click.echo(f"\nšŸ“ Generated Files:") click.echo(f" - Model: {result['model_file']}") click.echo(f" - Vocab: {result['vocab_file']}") click.echo(f"\nšŸ“ Next Steps:") click.echo(f" 1. Use this tokenizer in your pretraining config:") click.echo(f" dataset:") click.echo(f" local: true") click.echo(f" jsonl_path: {jsonl_path}") click.echo(f" tokenizer_path: {result['model_file']}") click.echo(f"") click.echo(f" 2. Run pretraining with:") click.echo(f" train pretrain --config your_config.yaml") except ImportError as e: click.echo(f"āŒ Error: {e}", err=True) sys.exit(1) except FileNotFoundError as e: click.echo(f"āŒ File Error: {e}", err=True) sys.exit(1) except ValueError as e: click.echo(f"āŒ Validation Error: {e}", err=True) sys.exit(1) except Exception as e: click.echo(f"āŒ Unexpected Error: {e}", err=True) click.echo(f" Please report this issue.", err=True) sys.exit(1) # Keep legacy CLI command as train-tokenizer-legacy for backward compatibility main.add_command(train_tokenizer_command, name="train-tokenizer-legacy") def _train_command(config_path: str, mode: TrainingModeEnum): """Internal training command.""" try: # Load config click.echo(f"Loading config from {config_path}...") train_config = load_config(config_path, mode) # Set seed set_seed(train_config.seed) # Get device device = get_device(train_config.device) click.echo(f"Using device: {device}") # Create model click.echo("Creating model...") model = create_model(train_config, device) total_params, trainable_params = _count_params(model) click.echo(f" - Total parameters: {total_params:,}") click.echo(f" - Trainable parameters: {trainable_params:,}") # Load pretrained checkpoint if provided (for SFT/RL) if train_config.checkpoint_path: click.echo(f"Loading pretrained checkpoint from {train_config.checkpoint_path}...") from taoTrain.checkpointing.checkpoint import CheckpointManager checkpoint_manager = CheckpointManager(train_config.checkpoint_dir) checkpoint = checkpoint_manager.load(train_config.checkpoint_path, device=device) # CheckpointManager.load() normalizes format and ensures 'model_state' key exists if "model_state" in checkpoint: model.load_state_dict(checkpoint["model_state"], strict=False) click.echo(" āœ“ Checkpoint loaded successfully") else: raise KeyError(f"Invalid checkpoint format: 'model_state' key not found. " f"Available keys: {list(checkpoint.keys())}") # Create datasets click.echo("Loading datasets...") train_dataset, val_dataset = create_datasets(train_config) click.echo(f" - Train samples: {len(train_dataset)}") if val_dataset: click.echo(f" - Val samples: {len(val_dataset)}") # Select trainer if mode == TrainingModeEnum.PRETRAIN: trainer_class = PretrainTrainer elif mode == TrainingModeEnum.SFT: trainer_class = SFTTrainer elif mode == TrainingModeEnum.RL: trainer_class = RLTrainer else: raise ValueError(f"Unknown training mode: {mode}") # Create trainer click.echo("Setting up trainer...") trainer = trainer_class( model=model, train_dataset=train_dataset, val_dataset=val_dataset, config=train_config, device=device, ) # Training loop click.echo("\nStarting training...\n") for epoch in range(train_config.num_epochs): if train_config.max_steps and trainer.global_step >= train_config.max_steps: break epoch_metrics = trainer.train_epoch() click.echo(f"\nEpoch {epoch + 1} complete") click.echo(f" - Loss: {epoch_metrics.get('loss', 'N/A')}") click.echo(f" - Learning rate: {epoch_metrics.get('lr', 'N/A')}") # Final checkpoint final_path = Path(train_config.checkpoint_dir) / "final_model.pt" trainer.save_checkpoint(final_path) click.echo(f"\nTraining complete! Final model saved to {final_path}") # Log finish trainer.logger.finish() except Exception as e: click.echo(f"Error during training: {e}", err=True) sys.exit(1) @main.command() @click.option( "--model", type=click.Path(exists=True), required=True, help="Path to model checkpoint", ) @click.option( "--benchmark-type", type=click.Choice(["all", "perplexity", "throughput", "memory"]), default="all", help="Type of benchmark to run", ) @click.option( "--batch-size", type=int, default=32, help="Batch size for benchmarking", ) @click.option( "--seq-length", type=int, default=1024, help="Sequence length for benchmarking", ) def benchmark(model: str, benchmark_type: str, batch_size: int, seq_length: int): """Benchmark a trained model.""" try: click.echo(f"Loading model from {model}...") device = get_device("cuda") runner = BenchmarkRunner.load_from_checkpoint(model, device=device) click.echo("Running benchmarks...\n") if benchmark_type == "throughput" or benchmark_type == "all": click.echo("Throughput benchmark:") results = runner.benchmark_throughput(batch_size, seq_length) for key, val in results.items(): click.echo(f" {key}: {val:.2f}") if benchmark_type == "memory" or benchmark_type == "all": click.echo("\nMemory benchmark:") results = runner.benchmark_memory() for key, val in results.items(): click.echo(f" {key}: {val:.2f}") click.echo("\nBenchmarking complete!") except Exception as e: click.echo(f"Error during benchmarking: {e}", err=True) sys.exit(1) @main.command() @click.option("--repo", type=str, default=".aim", help="AimStack repository path") def view_logs(repo: str): """View training logs with AimStack.""" try: import subprocess click.echo(f"Opening AimStack dashboard for repo: {repo}") subprocess.run(["aim", "up", "--repo", repo]) except FileNotFoundError: click.echo("Error: 'aim' command not found. Install with: pip install aim", err=True) sys.exit(1) def _count_params(model: BaseModel) -> tuple[int, int]: """Count model parameters.""" total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) return total, trainable if __name__ == "__main__": main()