"""Enhanced CLI for KerdosAI with rich output and better UX."""

import logging
from pathlib import Path
from typing import Optional

import typer
from rich import print as rprint
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from config import KerdosConfig, load_config
from exceptions import KerdosError

# Initialize Typer app and Rich console
app = typer.Typer(
    name="kerdosai",
    help="KerdosAI - Universal LLM Training Agent",
    add_completion=False,
)
console = Console()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


def _print_error(label: str, error: BaseException) -> None:
    """Print ``label`` in bold red followed by the error's message.

    The message is markup-escaped: exception text (e.g. validation errors)
    may contain square brackets that Rich would otherwise treat as markup
    tags, either dropping that text or raising ``MarkupError``.
    """
    console.print(f"\n[bold red]{label}:[/bold red] {escape(str(error))}")


@app.command()
def train(
    config_file: Optional[Path] = typer.Option(
        None, "--config", "-c", help="Path to configuration file", exists=True
    ),
    model: Optional[str] = typer.Option(
        None, "--model", "-m", help="Base model name or path"
    ),
    data: Optional[Path] = typer.Option(
        None, "--data", "-d", help="Path to training data"
    ),
    output: Path = typer.Option(
        "./output", "--output", "-o", help="Output directory"
    ),
    epochs: int = typer.Option(3, "--epochs", "-e", help="Number of epochs"),
    batch_size: int = typer.Option(4, "--batch-size", "-b", help="Batch size"),
    use_lora: bool = typer.Option(True, "--lora/--no-lora", help="Use LoRA"),
    use_quantization: bool = typer.Option(
        False, "--quantize/--no-quantize", help="Use quantization"
    ),
):
    """
    Train a language model with custom data.
    """
    # Flow: load config -> apply CLI overrides -> validate -> build agent ->
    # optionally prepare LoRA/quantization -> train -> save -> show metrics.
    # KerdosError is reported without a traceback; any other exception is
    # also logged with its traceback. Both exit with status 1.
    try:
        console.print(Panel.fit(
            "[bold cyan]KerdosAI Training[/bold cyan]",
            subtitle="Universal LLM Training Agent",
        ))

        # Load configuration
        if config_file:
            console.print(
                "📄 Loading configuration from: "
                f"[cyan]{escape(str(config_file))}[/cyan]"
            )
            config = load_config(config_file)
        else:
            console.print("⚙️ Using default configuration")
            config = load_config()

        # Override with CLI arguments.
        # NOTE(review): output, epochs, batch size, LoRA and quantization have
        # non-None CLI defaults, so they always replace the values loaded from
        # --config even when the user did not pass those flags. Consider making
        # them Optional[...] = None and overriding only when explicitly set.
        if model:
            config.base_model = model
        if data:
            config.data.train_file = str(data)
        config.output_dir = str(output)
        config.training.epochs = epochs
        config.training.batch_size = batch_size
        config.lora.enabled = use_lora
        config.quantization.enabled = use_quantization

        # Validate configuration
        config.validate_compatibility()

        # Display configuration
        _display_config(config)

        # Import here to avoid slow startup
        from agent import KerdosAgent

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            # Initialize agent
            task = progress.add_task("Initializing agent...", total=None)
            # base_model and training_data are passed explicitly below; drop
            # any same-named keys from the config dump, otherwise Python raises
            # "got multiple values for keyword argument" at the call site.
            agent_kwargs = config.model_dump()
            agent_kwargs.pop("base_model", None)
            agent_kwargs.pop("training_data", None)
            agent = KerdosAgent(
                base_model=config.base_model,
                training_data=config.data.train_file,
                **agent_kwargs,
            )
            progress.update(task, description="✓ Agent initialized")

            # Prepare for training (only needed when LoRA or quantization is on)
            task = progress.add_task("Preparing model for training...", total=None)
            if config.lora.enabled or config.quantization.enabled:
                agent.prepare_for_training(
                    use_lora=config.lora.enabled,
                    lora_r=config.lora.r,
                    lora_alpha=config.lora.alpha,
                    use_4bit=config.quantization.enabled and config.quantization.bits == 4,
                    use_8bit=config.quantization.enabled and config.quantization.bits == 8,
                )
            progress.update(task, description="✓ Model prepared")

            # Train model
            console.print("\n🚀 Starting training...")
            metrics = agent.train(
                epochs=config.training.epochs,
                batch_size=config.training.batch_size,
                learning_rate=config.training.learning_rate,
            )

            # Save model
            console.print(
                "\n💾 Saving model to: "
                f"[cyan]{escape(str(config.output_dir))}[/cyan]"
            )
            agent.save(config.output_dir)

        # Display results
        console.print("\n[bold green]✓ Training completed successfully![/bold green]")
        _display_metrics(metrics)

    except KerdosError as e:
        _print_error("Error", e)
        raise typer.Exit(code=1)
    except Exception as e:
        _print_error("Unexpected error", e)
        logger.exception("Training failed")
        raise typer.Exit(code=1)


@app.command()
def generate(
    model_dir: Path = typer.Argument(..., help="Path to trained model"),
    prompt: str = typer.Option(..., "--prompt", "-p", help="Input prompt"),
    max_length: int = typer.Option(100, "--max-length", help="Maximum length"),
    temperature: float = typer.Option(0.7, "--temperature", "-t", help="Temperature"),
):
    """
    Generate text from a trained model.
    """
    # Loads the agent from model_dir, runs a single generation and prints it
    # in a panel. Failures print a message and exit with status 1; unexpected
    # (non-KerdosError) failures are additionally logged with a traceback.
    try:
        console.print(Panel.fit("[bold cyan]KerdosAI Generation[/bold cyan]"))

        # Import here to avoid slow startup
        from agent import KerdosAgent

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task("Loading model...", total=None)
            agent = KerdosAgent.load(model_dir)
            progress.update(task, description="✓ Model loaded")

            task = progress.add_task("Generating...", total=None)
            output = agent.generate(
                prompt=prompt,
                max_length=max_length,
                temperature=temperature,
            )
            progress.update(task, description="✓ Generation complete")

        console.print("\n[bold]Generated Text:[/bold]")
        # Model output is arbitrary text: escape it so sequences such as
        # "[/INST]" are shown literally instead of being parsed as Rich markup.
        console.print(Panel(escape(str(output)), border_style="green"))

    except KerdosError as e:
        _print_error("Error", e)
        raise typer.Exit(code=1)
    except Exception as e:
        _print_error("Error", e)
        logger.exception("Generation failed")
        raise typer.Exit(code=1)


@app.command()
def info(
    model_dir: Optional[Path] = typer.Argument(None, help="Path to model (optional)"),
):
    """
    Display model information.
    """
    # With a model_dir, loads the agent and tabulates agent.get_model_info();
    # without one, prints a static "About" panel.
    try:
        if model_dir:
            from agent import KerdosAgent

            console.print(
                f"📊 Loading model info from: [cyan]{escape(str(model_dir))}[/cyan]\n"
            )
            agent = KerdosAgent.load(model_dir)
            info_dict = agent.get_model_info()

            table = Table(title="Model Information", show_header=True)
            table.add_column("Property", style="cyan")
            table.add_column("Value", style="green")

            for key, value in info_dict.items():
                if key == "trainable_percentage":
                    table.add_row(key, f"{value:.2f}%")
                elif isinstance(value, (int, float)) and not isinstance(value, bool):
                    # bool is a subclass of int; without the exclusion the ","
                    # format spec would render True/False as "1"/"0".
                    table.add_row(key, f"{value:,}")
                else:
                    table.add_row(key, str(value))

            console.print(table)
        else:
            console.print(Panel.fit(
                "[bold cyan]KerdosAI[/bold cyan]\n"
                "Version: 0.2.0\n"
                "Universal LLM Training Agent",
                title="About",
            ))

    except KerdosError as e:
        _print_error("Error", e)
        raise typer.Exit(code=1)
    except Exception as e:
        _print_error("Error", e)
        logger.exception("Failed to display model info")
        raise typer.Exit(code=1)


@app.command()
def validate_config(
    config_file: Path = typer.Argument(
        ..., help="Path to configuration file", exists=True
    ),
):
    """
    Validate a configuration file.
    """
    # Loads the file, runs validate_compatibility() and, on success, shows the
    # resulting configuration table. Any load/parse/compatibility error is
    # reported as a validation failure and exits with status 1.
    try:
        console.print(
            f"🔍 Validating configuration: [cyan]{escape(str(config_file))}[/cyan]\n"
        )
        config = load_config(config_file)
        config.validate_compatibility()

        console.print("[bold green]✓ Configuration is valid![/bold green]\n")
        _display_config(config)

    except KerdosError as e:
        _print_error("Validation failed", e)
        raise typer.Exit(code=1)
    except Exception as e:
        # e.g. a malformed file or schema error raised by load_config that is
        # not wrapped in KerdosError; report it instead of dumping a traceback.
        _print_error("Validation failed", e)
        raise typer.Exit(code=1)


def _display_config(config: KerdosConfig):
    """Display configuration in a formatted table.

    LoRA rank/alpha and quantization bits rows are only shown when the
    corresponding feature is enabled.
    """
    table = Table(title="Configuration", show_header=True)
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="green")

    table.add_row("Base Model", config.base_model)
    table.add_row("Output Directory", config.output_dir)
    table.add_row("Epochs", str(config.training.epochs))
    table.add_row("Batch Size", str(config.training.batch_size))
    table.add_row("Learning Rate", f"{config.training.learning_rate:.2e}")
    table.add_row("LoRA Enabled", "✓" if config.lora.enabled else "✗")

    if config.lora.enabled:
        table.add_row("  LoRA Rank", str(config.lora.r))
        table.add_row("  LoRA Alpha", str(config.lora.alpha))

    table.add_row("Quantization", "✓" if config.quantization.enabled else "✗")
    if config.quantization.enabled:
        table.add_row("  Bits", str(config.quantization.bits))

    console.print(table)


def _display_metrics(metrics: dict):
    """Display training metrics in a formatted table.

    Floats are shown with 4 decimal places; other values via ``str``.
    Nothing is printed when ``metrics`` is empty or falsy.
    """
    if not metrics:
        return

    table = Table(title="Training Metrics", show_header=True)
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="green")

    for key, value in metrics.items():
        if isinstance(value, float):
            table.add_row(key, f"{value:.4f}")
        else:
            table.add_row(key, str(value))

    console.print(table)


def main():
    """Main entry point."""
    app()


if __name__ == "__main__":
    main()