StarMist0012's picture
Add files using upload-large-folder tool
e2bfccc verified
"""Main CLI entry point."""
import sys
from pathlib import Path
from typing import Optional
import click
import torch
from taoTrain.config import load_config, load_tokenizer_config, TrainingModeEnum, PretrainConfig, SFTConfig, RLConfig
from taoTrain.utils import set_seed, get_device
from taoTrain.core import BaseModel, create_model, create_datasets
from taoTrain.data import get_dataloader
from taoTrain.training import PretrainTrainer, SFTTrainer, RLTrainer
from taoTrain.benchmarks import BenchmarkRunner
from taoTrain.tokenizers import TokenizerTrainer
@click.group()
def main():
    """TaoTrain: A clean, modular PyTorch LLM training framework."""
    # Root command group. It does no work itself; every subcommand is
    # attached via `@main.command()` or `main.add_command(...)`.
@main.command()
@click.option(
    "--config",
    required=True,
    type=click.Path(exists=True),
    help="Path to training config file (YAML or JSON)",
)
def pretrain(config: str):
    """Pretrain a language model."""
    # Thin wrapper: the shared pipeline does the work; only the mode differs.
    _train_command(config_path=config, mode=TrainingModeEnum.PRETRAIN)
@main.command()
@click.option(
    "--config",
    required=True,
    type=click.Path(exists=True),
    help="Path to training config file (YAML or JSON)",
)
def sft(config: str):
    """Supervised fine-tune a language model."""
    # Thin wrapper: the shared pipeline does the work; only the mode differs.
    _train_command(config_path=config, mode=TrainingModeEnum.SFT)
@main.command()
@click.option(
    "--config",
    required=True,
    type=click.Path(exists=True),
    help="Path to training config file (YAML or JSON)",
)
def rl(config: str):
    """Train with reinforcement learning."""
    # Thin wrapper: the shared pipeline does the work; only the mode differs.
    _train_command(config_path=config, mode=TrainingModeEnum.RL)
@main.command()
@click.option(
    "--config",
    required=True,
    type=click.Path(exists=True),
    help="Path to tokenizer config file (YAML or JSON)",
)
def train_tokenizer(config: str):
    """Train a SentencePiece tokenizer from a YAML/JSON config file."""
    echo = click.echo
    divider = "=" * 50
    try:
        echo("🚀 TaoTrain Tokenizer Trainer")
        echo(divider)

        echo(f"Loading config from {config}...")
        tok_cfg = load_tokenizer_config(config)
        outcome = TokenizerTrainer.train_from_config(tok_cfg)

        # Summary of what was trained and where the artifacts were written.
        echo("\n" + divider)
        echo("✅ Tokenizer Training Complete!")
        echo("\n📊 Configuration:")
        echo(f" - Input file: {tok_cfg.jsonl_path}")
        echo(f" - Samples: {tok_cfg.max_samples or 'all'}")
        echo(f" - Output dir: {outcome['output_dir']}")
        echo(f" - Vocab size: {outcome['vocab_size']}")
        echo(f" - Model type: {outcome['model_type']}")
        if tok_cfg.special_tokens:
            echo(f" - Special tokens: {tok_cfg.special_tokens}")
        echo("\n📁 Generated Files:")
        echo(f" - Model: {outcome['model_file']}")
        echo(f" - Vocab: {outcome['vocab_file']}")

        # Copy-pasteable hints for wiring the tokenizer into pretraining.
        echo("\n📝 Next Steps:")
        echo(" 1. Use this tokenizer in your pretraining config:")
        echo(" dataset:")
        echo(" local: true")
        echo(f" jsonl_path: {tok_cfg.jsonl_path}")
        echo(f" tokenizer_path: {outcome['model_file']}")
        echo("")
        echo(" 2. Run pretraining with:")
        echo(" train pretrain --config your_config.yaml")
    except Exception as exc:
        # Known failure classes get a specific label. They are tested in a fixed
        # order, so the first matching type wins. Anything else is treated
        # as unexpected. Every path exits with status 1.
        for exc_type, label in (
            (ImportError, "❌ Error"),
            (FileNotFoundError, "❌ File Error"),
            (ValueError, "❌ Validation Error"),
        ):
            if isinstance(exc, exc_type):
                echo(f"{label}: {exc}", err=True)
                break
        else:
            echo(f"❌ Unexpected Error: {exc}", err=True)
            echo(" Please report this issue.", err=True)
        sys.exit(1)
@click.command()
@click.option(
    "--jsonl-path",
    type=click.Path(exists=True),
    required=True,
    help="Path to JSONL file containing training data",
)
@click.option(
    "--output-dir",
    type=click.Path(),
    default="tokenizers",
    help="Directory to save tokenizer files",
)
@click.option(
    "--vocab-size",
    # Validate at parse time: a non-positive vocabulary size is a user error
    # and should produce a click usage error before any training starts,
    # rather than surfacing later as a generic failure.
    type=click.IntRange(min=1),
    default=50000,
    help="Vocabulary size for the tokenizer",
)
@click.option(
    "--model-type",
    type=click.Choice(["unigram", "bpe", "char", "word"]),
    default="unigram",
    help="SentencePiece model type",
)
@click.option(
    "--character-coverage",
    # Character coverage is a fraction of characters, so only values in
    # [0, 1] are meaningful; reject anything else up front.
    type=click.FloatRange(0.0, 1.0),
    default=0.9995,
    help="Character coverage for SentencePiece",
)
@click.option(
    "--tokenizer-prefix",
    type=str,
    default=None,
    help="Prefix for tokenizer output files (default: model_type)",
)
def train_tokenizer_command(
    jsonl_path: str,
    output_dir: str,
    vocab_size: int,
    model_type: str,
    character_coverage: float,
    tokenizer_prefix: Optional[str],
):
    """Train a SentencePiece tokenizer from JSONL data."""
    # Flag-based variant of the config-driven tokenizer command. Options are
    # forwarded unchanged to TokenizerTrainer.train_sentencepiece, and a
    # summary of the returned result dict is printed.
    try:
        click.echo("🚀 TaoTrain Tokenizer Trainer")
        click.echo("=" * 50)

        result = TokenizerTrainer.train_sentencepiece(
            jsonl_path=jsonl_path,
            output_dir=output_dir,
            vocab_size=vocab_size,
            model_type=model_type,
            character_coverage=character_coverage,
            tokenizer_prefix=tokenizer_prefix,
        )

        # Summary of what was trained and where the artifacts were written.
        click.echo(f"\n{'=' * 50}")
        click.echo("✅ Tokenizer Training Complete!")
        click.echo("\n📊 Configuration:")
        click.echo(f" - Input file: {jsonl_path}")
        click.echo(f" - Output dir: {result['output_dir']}")
        click.echo(f" - Vocab size: {result['vocab_size']}")
        click.echo(f" - Model type: {result['model_type']}")
        click.echo("\n📁 Generated Files:")
        click.echo(f" - Model: {result['model_file']}")
        click.echo(f" - Vocab: {result['vocab_file']}")

        # Copy-pasteable hints for wiring the tokenizer into pretraining.
        click.echo("\n📝 Next Steps:")
        click.echo(" 1. Use this tokenizer in your pretraining config:")
        click.echo(" dataset:")
        click.echo(" local: true")
        click.echo(f" jsonl_path: {jsonl_path}")
        click.echo(f" tokenizer_path: {result['model_file']}")
        click.echo("")
        click.echo(" 2. Run pretraining with:")
        click.echo(" train pretrain --config your_config.yaml")
    except ImportError as e:
        click.echo(f"❌ Error: {e}", err=True)
        sys.exit(1)
    except FileNotFoundError as e:
        click.echo(f"❌ File Error: {e}", err=True)
        sys.exit(1)
    except ValueError as e:
        click.echo(f"❌ Validation Error: {e}", err=True)
        sys.exit(1)
    except Exception as e:
        click.echo(f"❌ Unexpected Error: {e}", err=True)
        click.echo(" Please report this issue.", err=True)
        sys.exit(1)
# Register the flag-based tokenizer command under its legacy name, so that
# existing invocations of `train-tokenizer-legacy` keep working alongside
# the config-driven `train_tokenizer` command.
main.add_command(train_tokenizer_command, "train-tokenizer-legacy")
def _train_command(config_path: str, mode: TrainingModeEnum):
    """Run the shared training pipeline behind the pretrain/sft/rl commands.

    Steps:
    - load the config for *mode*, seed RNGs and resolve the device;
    - build the model, optionally warm-starting from ``checkpoint_path``;
    - build the datasets;
    - run epochs until ``num_epochs`` is exhausted or ``max_steps`` is reached;
    - save ``final_model.pt`` under ``checkpoint_dir``.

    Any exception raised along the way is reported on stderr and the
    process exits with status 1. Once a trainer has been constructed, its
    ``logger.finish()`` is called on every exit path: success, error, or
    interruption such as Ctrl+C.

    Args:
        config_path: Path to the YAML/JSON training config.
        mode: Training stage. It is passed to ``load_config`` and selects
            the trainer class.
    """
    trainer = None
    try:
        click.echo(f"Loading config from {config_path}...")
        train_config = load_config(config_path, mode)

        set_seed(train_config.seed)

        device = get_device(train_config.device)
        click.echo(f"Using device: {device}")

        click.echo("Creating model...")
        model = create_model(train_config, device)
        total_params, trainable_params = _count_params(model)
        click.echo(f" - Total parameters: {total_params:,}")
        click.echo(f" - Trainable parameters: {trainable_params:,}")

        # Optionally warm-start from a pretrained checkpoint (for SFT/RL).
        if train_config.checkpoint_path:
            click.echo(f"Loading pretrained checkpoint from {train_config.checkpoint_path}...")
            from taoTrain.checkpointing.checkpoint import CheckpointManager

            checkpoint_manager = CheckpointManager(train_config.checkpoint_dir)
            checkpoint = checkpoint_manager.load(train_config.checkpoint_path, device=device)
            if "model_state" not in checkpoint:
                raise KeyError(
                    f"Invalid checkpoint format: 'model_state' key not found. "
                    f"Available keys: {list(checkpoint.keys())}"
                )
            load_result = model.load_state_dict(checkpoint["model_state"], strict=False)
            click.echo(" ✓ Checkpoint loaded successfully")
            # strict=False silently skips keys that don't line up between the
            # model and the checkpoint. Report them so a wrong or partial
            # checkpoint is not mistaken for a clean load. getattr() keeps
            # this safe if a model's load_state_dict returns something
            # without these fields.
            missing_keys = list(getattr(load_result, "missing_keys", None) or [])
            unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
            if missing_keys:
                click.echo(
                    f" ! {len(missing_keys)} model key(s) not found in checkpoint "
                    f"(not loaded), e.g. {missing_keys[:5]}",
                    err=True,
                )
            if unexpected_keys:
                click.echo(
                    f" ! {len(unexpected_keys)} checkpoint key(s) not used by the model "
                    f"(ignored), e.g. {unexpected_keys[:5]}",
                    err=True,
                )

        click.echo("Loading datasets...")
        train_dataset, val_dataset = create_datasets(train_config)
        click.echo(f" - Train samples: {len(train_dataset)}")
        # Compare against None explicitly: a truthiness check would hide an
        # empty validation set instead of reporting "0".
        if val_dataset is not None:
            click.echo(f" - Val samples: {len(val_dataset)}")

        if mode == TrainingModeEnum.PRETRAIN:
            trainer_class = PretrainTrainer
        elif mode == TrainingModeEnum.SFT:
            trainer_class = SFTTrainer
        elif mode == TrainingModeEnum.RL:
            trainer_class = RLTrainer
        else:
            raise ValueError(f"Unknown training mode: {mode}")

        click.echo("Setting up trainer...")
        trainer = trainer_class(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            config=train_config,
            device=device,
        )

        click.echo("\nStarting training...\n")
        for epoch in range(train_config.num_epochs):
            # max_steps (if set) caps training across epochs.
            if train_config.max_steps and trainer.global_step >= train_config.max_steps:
                break
            epoch_metrics = trainer.train_epoch()
            click.echo(f"\nEpoch {epoch + 1} complete")
            click.echo(f" - Loss: {epoch_metrics.get('loss', 'N/A')}")
            click.echo(f" - Learning rate: {epoch_metrics.get('lr', 'N/A')}")

        final_path = Path(train_config.checkpoint_dir) / "final_model.pt"
        trainer.save_checkpoint(final_path)
        click.echo(f"\nTraining complete! Final model saved to {final_path}")
    except Exception as e:
        click.echo(f"Error during training: {e}", err=True)
        sys.exit(1)
    finally:
        # Always finalize the logger once a trainer exists. This also covers
        # failures and KeyboardInterrupt, which previously skipped finish().
        # A failure here is only a warning, so it never masks the outcome
        # of training itself.
        if trainer is not None:
            try:
                trainer.logger.finish()
            except Exception as finish_error:
                click.echo(f"Warning: failed to finalize logger: {finish_error}", err=True)
@main.command()
@click.option(
    "--model",
    type=click.Path(exists=True),
    required=True,
    help="Path to model checkpoint",
)
@click.option(
    "--benchmark-type",
    type=click.Choice(["all", "perplexity", "throughput", "memory"]),
    default="all",
    help="Type of benchmark to run",
)
@click.option(
    "--batch-size",
    type=int,
    default=32,
    help="Batch size for benchmarking",
)
@click.option(
    "--seq-length",
    type=int,
    default=1024,
    help="Sequence length for benchmarking",
)
@click.option(
    "--device",
    type=str,
    default="cuda",
    help="Device name passed to get_device (default: cuda)",
)
def benchmark(
    model: str,
    benchmark_type: str,
    batch_size: int,
    seq_length: int,
    device: str = "cuda",
):
    """Benchmark a trained model."""
    # Loads a BenchmarkRunner from the checkpoint at `model`, then prints
    # each metric returned by the selected benchmarks. Any failure is
    # reported on stderr and the process exits with status 1.

    # "perplexity" is accepted by --benchmark-type, but this command has no
    # perplexity benchmark wired up. Fail clearly instead of silently
    # running nothing and exiting 0.
    if benchmark_type == "perplexity":
        click.echo(
            "Error: the perplexity benchmark is not available from this command.",
            err=True,
        )
        sys.exit(1)

    def _format_metric(value) -> str:
        # Numeric metrics keep two-decimal formatting. A value that rejects
        # ".2f" (e.g. a string or None) is printed as-is, so one odd metric
        # does not abort the whole run.
        try:
            return format(value, ".2f")
        except (TypeError, ValueError):
            return str(value)

    try:
        click.echo(f"Loading model from {model}...")
        target_device = get_device(device)
        runner = BenchmarkRunner.load_from_checkpoint(model, device=target_device)

        click.echo("Running benchmarks...\n")
        if benchmark_type in ("throughput", "all"):
            click.echo("Throughput benchmark:")
            results = runner.benchmark_throughput(batch_size, seq_length)
            for key, val in results.items():
                click.echo(f" {key}: {_format_metric(val)}")
        if benchmark_type in ("memory", "all"):
            click.echo("\nMemory benchmark:")
            results = runner.benchmark_memory()
            for key, val in results.items():
                click.echo(f" {key}: {_format_metric(val)}")
        if benchmark_type == "all":
            click.echo("\nNote: perplexity benchmark skipped (not available from this command).")
        click.echo("\nBenchmarking complete!")
    except Exception as e:
        click.echo(f"Error during benchmarking: {e}", err=True)
        sys.exit(1)
@main.command()
@click.option("--repo", type=str, default=".aim", help="AimStack repository path")
def view_logs(repo: str):
    """View training logs with AimStack."""
    # Launches `aim up --repo <repo>` as a child process and waits for it.
    # The exit status mirrors aim's, so scripts can detect a failed launch.
    import subprocess

    click.echo(f"Opening AimStack dashboard for repo: {repo}")
    try:
        completed = subprocess.run(["aim", "up", "--repo", repo])
    except FileNotFoundError:
        click.echo("Error: 'aim' command not found. Install with: pip install aim", err=True)
        sys.exit(1)
    # Previously a failing `aim` still produced exit status 0. Propagate the
    # failure instead. A negative return code means the child was killed by
    # a signal, which is not a valid process exit status, so map it to 1.
    if completed.returncode != 0:
        sys.exit(completed.returncode if completed.returncode > 0 else 1)
def _count_params(model: BaseModel) -> tuple[int, int]:
    """Return ``(total, trainable)`` parameter-element counts for *model*.

    "Trainable" counts only parameters with ``requires_grad`` set. Both
    totals are summed in a single pass over ``model.parameters()``.
    """
    total = trainable = 0
    for param in model.parameters():
        elements = param.numel()
        total += elements
        if param.requires_grad:
            trainable += elements
    return total, trainable
# Allow running this module directly as a script; dispatches to the click group.
if __name__ == "__main__":
    main()