Anonymous Hunter
feat: Add robust configuration management, Docker support, initial testing, and quickstart documentation.
f21249a
| """ | |
| Enhanced CLI for KerdosAI with rich output and better UX. | |
| """ | |
| import typer | |
| from typing import Optional | |
| from pathlib import Path | |
| from rich.console import Console | |
| from rich.table import Table | |
| from rich.progress import Progress, SpinnerColumn, TextColumn | |
| from rich.panel import Panel | |
| from rich import print as rprint | |
| import logging | |
| from config import load_config, KerdosConfig | |
| from exceptions import KerdosError | |
# Root logger setup happens at import time so every command shares one format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Single Rich console shared by all commands and display helpers.
console = Console()

# Typer application object; shell-completion options are switched off.
app = typer.Typer(
    add_completion=False,
    help="KerdosAI - Universal LLM Training Agent",
    name="kerdosai",
)
@app.command()
def train(
    config_file: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        help="Path to configuration file",
        exists=True
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Base model name or path"
    ),
    data: Optional[Path] = typer.Option(
        None,
        "--data",
        "-d",
        help="Path to training data"
    ),
    output: Path = typer.Option(
        "./output",
        "--output",
        "-o",
        help="Output directory"
    ),
    epochs: int = typer.Option(3, "--epochs", "-e", help="Number of epochs"),
    batch_size: int = typer.Option(4, "--batch-size", "-b", help="Batch size"),
    use_lora: bool = typer.Option(True, "--lora/--no-lora", help="Use LoRA"),
    use_quantization: bool = typer.Option(
        False,
        "--quantize/--no-quantize",
        help="Use quantization"
    ),
):
    """
    Train a language model with custom data.
    """
    # The docstring above doubles as the command's --help text, so the
    # implementation notes are kept in comments.
    #
    # Flow: load configuration -> apply CLI overrides -> validate -> build the
    # agent -> optionally prepare LoRA/quantization -> train -> save -> show
    # metrics. Every failure is reported on the console and converted into
    # exit code 1 via typer.Exit.
    try:
        console.print(Panel.fit(
            "[bold cyan]KerdosAI Training[/bold cyan]",
            subtitle="Universal LLM Training Agent"
        ))

        # Load configuration
        if config_file:
            console.print(f"📁 Loading configuration from: [cyan]{config_file}[/cyan]")
            config = load_config(config_file)
        else:
            console.print("⚙️ Using default configuration")
            config = load_config()

        # Override with CLI arguments
        if model:
            config.base_model = model
        if data:
            config.data.train_file = str(data)
        # NOTE(review): the options below always carry a value (their CLI
        # defaults), so they overwrite whatever the configuration file set even
        # when the flag was not given on the command line. Confirm that this
        # precedence is intended.
        config.output_dir = str(output)
        config.training.epochs = epochs
        config.training.batch_size = batch_size
        config.lora.enabled = use_lora
        config.quantization.enabled = use_quantization

        # Validate configuration
        config.validate_compatibility()

        # Display configuration
        _display_config(config)

        # Import here to avoid slow startup
        from agent import KerdosAgent

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            # Initialize agent
            task = progress.add_task("Initializing agent...", total=None)
            # base_model and training_data are passed explicitly, so they are
            # removed from the dumped configuration; supplying the same
            # keyword twice raises TypeError before the agent is constructed.
            agent_kwargs = {
                key: value
                for key, value in config.model_dump().items()
                if key not in ("base_model", "training_data")
            }
            agent = KerdosAgent(
                base_model=config.base_model,
                training_data=config.data.train_file,
                **agent_kwargs
            )
            progress.update(task, description="✓ Agent initialized")

            # Prepare for training
            task = progress.add_task("Preparing model for training...", total=None)
            if config.lora.enabled or config.quantization.enabled:
                agent.prepare_for_training(
                    use_lora=config.lora.enabled,
                    lora_r=config.lora.r,
                    lora_alpha=config.lora.alpha,
                    use_4bit=config.quantization.enabled and config.quantization.bits == 4,
                    use_8bit=config.quantization.enabled and config.quantization.bits == 8
                )
            progress.update(task, description="✓ Model prepared")

        # Train model
        console.print("\n🚀 Starting training...")
        metrics = agent.train(
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            learning_rate=config.training.learning_rate
        )

        # Save model
        console.print(f"\n💾 Saving model to: [cyan]{config.output_dir}[/cyan]")
        agent.save(config.output_dir)

        # Display results
        console.print("\n[bold green]✓ Training completed successfully![/bold green]")
        _display_metrics(metrics)

    except KerdosError as e:
        # Expected, domain-level failure: show the message without a traceback.
        console.print(f"\n[bold red]Error:[/bold red] {e}")
        raise typer.Exit(code=1) from e
    except Exception as e:
        # Anything else is unexpected: keep the traceback in the log.
        console.print(f"\n[bold red]Unexpected error:[/bold red] {e}")
        logger.exception("Training failed")
        raise typer.Exit(code=1) from e
@app.command()
def generate(
    model_dir: Path = typer.Argument(..., help="Path to trained model"),
    prompt: str = typer.Option(..., "--prompt", "-p", help="Input prompt"),
    max_length: int = typer.Option(100, "--max-length", help="Maximum length"),
    temperature: float = typer.Option(0.7, "--temperature", "-t", help="Temperature"),
):
    """
    Generate text from a trained model.
    """
    # The docstring above doubles as the command's --help text, so the
    # implementation notes are kept in comments.
    #
    # Loads the agent from model_dir, calls agent.generate() with the prompt
    # and sampling options, and prints the result in a green panel. Any
    # failure is printed, logged with its traceback, and turned into exit
    # code 1.
    try:
        console.print(Panel.fit("[bold cyan]KerdosAI Generation[/bold cyan]"))

        # Import here to avoid slow startup
        from agent import KerdosAgent

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Loading model...", total=None)
            agent = KerdosAgent.load(model_dir)
            progress.update(task, description="✓ Model loaded")

            task = progress.add_task("Generating...", total=None)
            output = agent.generate(
                prompt=prompt,
                max_length=max_length,
                temperature=temperature
            )
            progress.update(task, description="✓ Generation complete")

        console.print("\n[bold]Generated Text:[/bold]")
        console.print(Panel(output, border_style="green"))

    except Exception as e:
        console.print(f"\n[bold red]Error:[/bold red] {e}")
        # Keep the traceback available for debugging, as the train command does.
        logger.exception("Generation failed")
        raise typer.Exit(code=1) from e
@app.command()
def info(
    model_dir: Optional[Path] = typer.Argument(None, help="Path to model (optional)"),
):
    """
    Display model information.
    """
    # The docstring above doubles as the command's --help text, so the
    # implementation notes are kept in comments.
    #
    # With a model_dir: load the agent and tabulate agent.get_model_info().
    # Without one: print a static "About" panel. Any failure is printed,
    # logged with its traceback, and turned into exit code 1.
    try:
        if model_dir:
            from agent import KerdosAgent

            console.print(f"📊 Loading model info from: [cyan]{model_dir}[/cyan]\n")
            agent = KerdosAgent.load(model_dir)
            info_dict = agent.get_model_info()

            table = Table(title="Model Information", show_header=True)
            table.add_column("Property", style="cyan")
            table.add_column("Value", style="green")

            for key, value in info_dict.items():
                # bool is a subclass of int; excluding it keeps flags readable
                # as True/False instead of being formatted as 1/0.
                is_number = (
                    isinstance(value, (int, float))
                    and not isinstance(value, bool)
                )
                if key == "trainable_percentage" and is_number:
                    table.add_row(str(key), f"{value:.2f}%")
                elif is_number:
                    # Thousands separators for counts such as parameter totals.
                    table.add_row(str(key), f"{value:,}")
                else:
                    table.add_row(str(key), str(value))

            console.print(table)
        else:
            console.print(Panel.fit(
                "[bold cyan]KerdosAI[/bold cyan]\n"
                "Version: 0.2.0\n"
                "Universal LLM Training Agent",
                title="About"
            ))

    except Exception as e:
        console.print(f"\n[bold red]Error:[/bold red] {e}")
        # Keep the traceback available for debugging, as the train command does.
        logger.exception("Failed to display model information")
        raise typer.Exit(code=1) from e
@app.command()
def validate_config(
    config_file: Path = typer.Argument(..., help="Path to configuration file"),
):
    """
    Validate a configuration file.
    """
    # The docstring above doubles as the command's --help text, so the
    # implementation notes are kept in comments.
    #
    # Loads config_file, runs its compatibility check, and prints the
    # resulting settings. Any failure ends the command with exit code 1.
    try:
        console.print(f"🔍 Validating configuration: [cyan]{config_file}[/cyan]\n")
        config = load_config(config_file)
        config.validate_compatibility()
        console.print("[bold green]✓ Configuration is valid![/bold green]\n")
        _display_config(config)
    except KerdosError as e:
        console.print(f"\n[bold red]Validation failed:[/bold red] {e}")
        raise typer.Exit(code=1) from e
    except Exception as e:
        # Failures that are not KerdosError (e.g. an unreadable file) are
        # reported the same way as in the other commands instead of surfacing
        # as a raw traceback.
        console.print(f"\n[bold red]Unexpected error:[/bold red] {e}")
        logger.exception("Configuration validation failed")
        raise typer.Exit(code=1) from e
def _display_config(config: KerdosConfig):
    """Display configuration in a formatted table.

    Prints the base model, output directory, training hyper-parameters and
    the LoRA / quantization switches to the shared console. LoRA rank/alpha
    and quantization bits are listed only when the corresponding feature is
    enabled.
    """
    # Distinct glyphs for the two states; the previous literals were the same
    # garbled character in both branches, so on/off could not be told apart.
    enabled_mark = "✓"
    disabled_mark = "✗"

    table = Table(title="Configuration", show_header=True)
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="green")

    table.add_row("Base Model", config.base_model)
    table.add_row("Output Directory", config.output_dir)
    table.add_row("Epochs", str(config.training.epochs))
    table.add_row("Batch Size", str(config.training.batch_size))
    table.add_row("Learning Rate", f"{config.training.learning_rate:.2e}")

    table.add_row("LoRA Enabled", enabled_mark if config.lora.enabled else disabled_mark)
    if config.lora.enabled:
        # Indented labels mark these rows as sub-settings of LoRA.
        table.add_row(" LoRA Rank", str(config.lora.r))
        table.add_row(" LoRA Alpha", str(config.lora.alpha))

    table.add_row("Quantization", enabled_mark if config.quantization.enabled else disabled_mark)
    if config.quantization.enabled:
        table.add_row(" Bits", str(config.quantization.bits))

    console.print(table)
def _display_metrics(metrics: dict):
    """Render training metrics as a two-column table on the shared console.

    Nothing is printed when ``metrics`` is empty or otherwise falsy. Float
    values are shown with four decimal places; every other value is shown
    through ``str()``.
    """
    if metrics:
        report = Table(title="Training Metrics", show_header=True)
        for heading, colour in (("Metric", "cyan"), ("Value", "green")):
            report.add_column(heading, style=colour)

        for name, reading in metrics.items():
            rendered = f"{reading:.4f}" if isinstance(reading, float) else str(reading)
            report.add_row(name, rendered)

        console.print(report)
def main():
    """Run the KerdosAI command-line application by invoking the Typer app."""
    app()


# Only start the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()