kerdosai / cli.py
Anonymous Hunter
feat: Add robust configuration management, Docker support, initial testing, and quickstart documentation.
f21249a
"""
Enhanced CLI for KerdosAI with rich output and better UX.
"""
import logging
from pathlib import Path
from typing import Optional

import typer
from rich import print as rprint
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich.text import Text

from config import load_config, KerdosConfig
from exceptions import KerdosError
# Typer application; every function decorated with @app.command() below is
# registered on it as a subcommand.
app = typer.Typer(
    name="kerdosai",
    help="KerdosAI - Universal LLM Training Agent",
    add_completion=False,
)

# Single Rich console shared by all commands for user-facing output.
console = Console()

# Root logging is configured when this module is imported: INFO level with a
# timestamped "time - logger - level - message" layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
# Historical CLI defaults for the ``train`` options below. They are applied
# only when an option is NOT given on the command line AND no --config file
# was supplied, so a plain ``kerdosai train`` behaves exactly as before, while
# values from a config file are no longer clobbered by these defaults.
_DEFAULT_OUTPUT_DIR = "./output"
_DEFAULT_EPOCHS = 3
_DEFAULT_BATCH_SIZE = 4
_DEFAULT_USE_LORA = True
_DEFAULT_USE_QUANTIZATION = False


def _apply_cli_override(target, attr, cli_value, fallback, use_fallback):
    """Assign one optional CLI value onto a configuration (sub-)object.

    Args:
        target: Configuration object or section to update.
        attr: Name of the attribute on ``target`` to set.
        cli_value: Value given on the command line, or ``None`` if the
            option was left unset.
        fallback: Value to set when the option was unset and
            ``use_fallback`` is true.
        use_fallback: Whether an unset option should receive ``fallback``
            (true when no configuration file was loaded); otherwise the
            existing value on ``target`` is kept.
    """
    if cli_value is not None:
        setattr(target, attr, cli_value)
    elif use_fallback:
        setattr(target, attr, fallback)


@app.command()
def train(
    config_file: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        help="Path to configuration file",
        exists=True,
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Base model name or path",
    ),
    data: Optional[Path] = typer.Option(
        None,
        "--data",
        "-d",
        help="Path to training data",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help=f"Output directory (default: {_DEFAULT_OUTPUT_DIR}, or the value from --config)",
    ),
    epochs: Optional[int] = typer.Option(
        None,
        "--epochs",
        "-e",
        help=f"Number of epochs (default: {_DEFAULT_EPOCHS}, or the value from --config)",
    ),
    batch_size: Optional[int] = typer.Option(
        None,
        "--batch-size",
        "-b",
        help=f"Batch size (default: {_DEFAULT_BATCH_SIZE}, or the value from --config)",
    ),
    use_lora: Optional[bool] = typer.Option(
        None,
        "--lora/--no-lora",
        help="Use LoRA (default: enabled, or the value from --config)",
    ),
    use_quantization: Optional[bool] = typer.Option(
        None,
        "--quantize/--no-quantize",
        help="Use quantization (default: disabled, or the value from --config)",
    ),
):
    """
    Train a language model with custom data.

    Options left unset keep the values from --config when one is given.
    """
    try:
        console.print(Panel.fit(
            "[bold cyan]KerdosAI Training[/bold cyan]",
            subtitle="Universal LLM Training Agent",
        ))

        # Load configuration: from the given file, or load_config()'s defaults.
        if config_file is not None:
            # escape(): keep brackets in the path from being parsed as markup.
            console.print(
                f"📄 Loading configuration from: [cyan]{escape(str(config_file))}[/cyan]"
            )
            config = load_config(config_file)
        else:
            console.print("⚙️ Using default configuration")
            config = load_config()

        # Layer CLI arguments on top of the configuration. Explicit options
        # always win; unset options keep the config file's values, or get the
        # historical CLI defaults when no config file was given. (Previously
        # the CLI defaults unconditionally overwrote the config file.)
        use_fallback = config_file is None
        if model:
            config.base_model = model
        if data:
            config.data.train_file = str(data)
        _apply_cli_override(
            config,
            "output_dir",
            None if output is None else str(output),
            _DEFAULT_OUTPUT_DIR,
            use_fallback,
        )
        _apply_cli_override(
            config.training, "epochs", epochs, _DEFAULT_EPOCHS, use_fallback
        )
        _apply_cli_override(
            config.training, "batch_size", batch_size, _DEFAULT_BATCH_SIZE, use_fallback
        )
        _apply_cli_override(
            config.lora, "enabled", use_lora, _DEFAULT_USE_LORA, use_fallback
        )
        _apply_cli_override(
            config.quantization,
            "enabled",
            use_quantization,
            _DEFAULT_USE_QUANTIZATION,
            use_fallback,
        )

        # Validate configuration
        config.validate_compatibility()

        # Display configuration
        _display_config(config)

        # Import here to avoid slow startup
        from agent import KerdosAgent

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            # Initialize agent.
            task = progress.add_task("Initializing agent...", total=None)
            # model_dump() already contains the ``base_model`` field, so
            # passing it next to an explicit base_model= keyword raised
            # "TypeError: got multiple values for keyword argument".
            # Build one mapping in which the explicit values take precedence.
            agent_kwargs = {
                **config.model_dump(),
                "base_model": config.base_model,
                "training_data": config.data.train_file,
            }
            agent = KerdosAgent(**agent_kwargs)
            progress.update(task, description="✓ Agent initialized")

            # Prepare for training (only needed for LoRA and/or quantization).
            task = progress.add_task("Preparing model for training...", total=None)
            if config.lora.enabled or config.quantization.enabled:
                agent.prepare_for_training(
                    use_lora=config.lora.enabled,
                    lora_r=config.lora.r,
                    lora_alpha=config.lora.alpha,
                    use_4bit=config.quantization.enabled and config.quantization.bits == 4,
                    use_8bit=config.quantization.enabled and config.quantization.bits == 8,
                )
            progress.update(task, description="✓ Model prepared")

        # Train model
        console.print("\n🚀 Starting training...")
        metrics = agent.train(
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            learning_rate=config.training.learning_rate,
        )

        # Save model
        console.print(f"\n💾 Saving model to: [cyan]{escape(str(config.output_dir))}[/cyan]")
        agent.save(config.output_dir)

        # Display results
        console.print("\n[bold green]✓ Training completed successfully![/bold green]")
        _display_metrics(metrics)

    except KerdosError as e:
        # Expected, user-facing failure: short message, no traceback.
        console.print(f"\n[bold red]Error:[/bold red] {escape(str(e))}")
        raise typer.Exit(code=1)
    except Exception as e:
        # Unexpected failure: short message on the console, full traceback
        # through the logger.
        console.print(f"\n[bold red]Unexpected error:[/bold red] {escape(str(e))}")
        logger.exception("Training failed")
        raise typer.Exit(code=1)
@app.command()
def generate(
    model_dir: Path = typer.Argument(..., help="Path to trained model"),
    prompt: str = typer.Option(..., "--prompt", "-p", help="Input prompt"),
    max_length: int = typer.Option(100, "--max-length", help="Maximum length"),
    temperature: float = typer.Option(0.7, "--temperature", "-t", help="Temperature"),
):
    """
    Generate text from a trained model.
    """
    try:
        console.print(Panel.fit("[bold cyan]KerdosAI Generation[/bold cyan]"))

        # Import here to avoid slow startup
        from agent import KerdosAgent

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            task = progress.add_task("Loading model...", total=None)
            agent = KerdosAgent.load(model_dir)
            progress.update(task, description="✓ Model loaded")

            task = progress.add_task("Generating...", total=None)
            output = agent.generate(
                prompt=prompt,
                max_length=max_length,
                temperature=temperature,
            )
            progress.update(task, description="✓ Generation complete")

        console.print("\n[bold]Generated Text:[/bold]")
        # Wrap the model output in Text so it is shown verbatim. A plain str
        # would be parsed as Rich markup: "[tag]" spans would disappear and a
        # stray "[/...]" would raise MarkupError.
        console.print(Panel(Text(str(output)), border_style="green"))

    except KerdosError as e:
        # Expected, user-facing failure: short message, no traceback.
        console.print(f"\n[bold red]Error:[/bold red] {escape(str(e))}")
        raise typer.Exit(code=1)
    except Exception as e:
        # Unexpected failure: same console message, plus a logged traceback
        # (consistent with the train command).
        console.print(f"\n[bold red]Error:[/bold red] {escape(str(e))}")
        logger.exception("Generation failed")
        raise typer.Exit(code=1)
def _format_info_value(key, value):
    """Format one entry of ``get_model_info()`` for the info table.

    Numbers under ``trainable_percentage`` are shown as a percentage with two
    decimals; other ints/floats use thousands separators. Booleans (an ``int``
    subclass, which previously rendered as ``1``/``0``) and all other values
    are rendered with ``str()``.
    """
    if isinstance(value, bool):
        return str(value)
    if isinstance(value, (int, float)):
        if key == "trainable_percentage":
            return f"{value:.2f}%"
        return f"{value:,}"
    return str(value)


@app.command()
def info(
    model_dir: Optional[Path] = typer.Argument(None, help="Path to model (optional)"),
):
    """
    Display model information.
    """
    try:
        if model_dir is not None:
            # Import here to avoid slow startup
            from agent import KerdosAgent

            console.print(
                f"📊 Loading model info from: [cyan]{escape(str(model_dir))}[/cyan]\n"
            )
            agent = KerdosAgent.load(model_dir)
            info_dict = agent.get_model_info()

            table = Table(title="Model Information", show_header=True)
            table.add_column("Property", style="cyan")
            table.add_column("Value", style="green")
            for key, value in info_dict.items():
                # str() + escape(): cells must be strings, and brackets in
                # values must not be interpreted as Rich markup.
                table.add_row(escape(str(key)), escape(_format_info_value(key, value)))
            console.print(table)
        else:
            # No model given: show general information about the tool.
            # NOTE(review): version is hard-coded here; keep in sync with the
            # package metadata.
            console.print(Panel.fit(
                "[bold cyan]KerdosAI[/bold cyan]\n"
                "Version: 0.2.0\n"
                "Universal LLM Training Agent",
                title="About",
            ))
    except KerdosError as e:
        # Expected, user-facing failure: short message, no traceback.
        console.print(f"\n[bold red]Error:[/bold red] {escape(str(e))}")
        raise typer.Exit(code=1)
    except Exception as e:
        # Unexpected failure: same console message, plus a logged traceback.
        console.print(f"\n[bold red]Error:[/bold red] {escape(str(e))}")
        logger.exception("Loading model info failed")
        raise typer.Exit(code=1)
@app.command()
def validate_config(
    config_file: Path = typer.Argument(..., help="Path to configuration file"),
):
    """
    Validate a configuration file.
    """
    try:
        console.print(
            f"🔍 Validating configuration: [cyan]{escape(str(config_file))}[/cyan]\n"
        )
        config = load_config(config_file)
        config.validate_compatibility()
    except KerdosError as e:
        console.print(f"\n[bold red]Validation failed:[/bold red] {escape(str(e))}")
        raise typer.Exit(code=1)
    except Exception as e:
        # Any other error raised while loading/validating (previously an
        # unhandled traceback) is reported the same way; the traceback is
        # still available at DEBUG log level.
        console.print(f"\n[bold red]Validation failed:[/bold red] {escape(str(e))}")
        logger.debug("Configuration validation failed", exc_info=True)
        raise typer.Exit(code=1)

    console.print("[bold green]✓ Configuration is valid![/bold green]\n")
    _display_config(config)
def _display_config(config: KerdosConfig):
    """Print the main settings of ``config`` as a two-column Rich table.

    LoRA rank/alpha and quantization bits are listed only when the
    corresponding feature is enabled.
    """

    def _text_cell(value):
        # Table cells must be str/renderables (a Path, for example, is not),
        # and brackets in model names or paths must not be parsed as Rich
        # markup. None stays a blank cell, as Table treats it.
        return None if value is None else escape(str(value))

    table = Table(title="Configuration", show_header=True)
    table.add_column("Setting", style="cyan")
    table.add_column("Value", style="green")

    table.add_row("Base Model", _text_cell(config.base_model))
    table.add_row("Output Directory", _text_cell(config.output_dir))
    table.add_row("Epochs", str(config.training.epochs))
    table.add_row("Batch Size", str(config.training.batch_size))
    table.add_row("Learning Rate", f"{config.training.learning_rate:.2e}")

    table.add_row("LoRA Enabled", "✓" if config.lora.enabled else "✗")
    if config.lora.enabled:
        table.add_row(" LoRA Rank", str(config.lora.r))
        table.add_row(" LoRA Alpha", str(config.lora.alpha))

    table.add_row("Quantization", "✓" if config.quantization.enabled else "✗")
    if config.quantization.enabled:
        table.add_row(" Bits", str(config.quantization.bits))

    console.print(table)
def _display_metrics(metrics: dict):
"""Display training metrics in a formatted table."""
if not metrics:
return
table = Table(title="Training Metrics", show_header=True)
table.add_column("Metric", style="cyan")
table.add_column("Value", style="green")
for key, value in metrics.items():
if isinstance(value, float):
table.add_row(key, f"{value:.4f}")
else:
table.add_row(key, str(value))
console.print(table)
def main():
    """Entry point: hand control over to the Typer application."""
    cli = app
    cli()


if __name__ == "__main__":
    main()