#!/usr/bin/env python3
"""
Generate Training Data CLI

Generate Q&A pairs from processed segments using Claude/GPT-4 API,
then build the final training dataset in ChatML format.

Usage:
    python scripts/generate_training_data.py --input data/processed/segments.json --output data/training/

Environment variables:
    ANTHROPIC_API_KEY - Required for Claude API
    OPENAI_API_KEY - Required for OpenAI API
"""

import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from rich.console import Console
from rich.prompt import Confirm
from rich.table import Table

from src.data_processing.qa_generator import QAGenerator, QAPair
from src.data_processing.dataset_builder import DatasetBuilder

console = Console()

# Environment variable holding the API key for each supported --provider.
_API_KEY_ENV_VARS = {
    "anthropic": "ANTHROPIC_API_KEY",
    "openai": "OPENAI_API_KEY",
}

# File name under the output directory where generated Q&A pairs are saved
# (and from which --skip-generation reloads them).
_QA_PAIRS_FILENAME = "qa_pairs.json"


@dataclass
class Segment:
    """Minimal segment record handed to the Q&A generator."""

    content: str
    segment_index: int
    source_post_title: str


def load_segments(path: Path) -> list:
    """Load segments from a JSON file.

    Args:
        path: Path to a JSON file containing a list of objects. Each object
            must have a "content" key; "segment_index" falls back to the
            item's position in the list and "source_post_title" falls back
            to "Unknown".

    Returns:
        A list of ``Segment`` objects, in file order.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an entry has no "content" key.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    return [
        Segment(
            content=s["content"],
            segment_index=s.get("segment_index", i),
            source_post_title=s.get("source_post_title", "Unknown"),
        )
        for i, s in enumerate(data)
    ]


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Generate Q&A training data using LLM APIs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Generate 500 Q&A pairs using Claude
    python scripts/generate_training_data.py \\
        --input data/processed/segments.json \\
        --output data/training/ \\
        --num-pairs 500

    # Use OpenAI instead
    python scripts/generate_training_data.py \\
        --input data/processed/segments.json \\
        --output data/training/ \\
        --provider openai

    # Just estimate cost without generating
    python scripts/generate_training_data.py \\
        --input data/processed/segments.json \\
        --estimate-only

    # Load existing Q&A pairs and just build dataset
    python scripts/generate_training_data.py \\
        --qa-pairs data/processed/qa_pairs.json \\
        --output data/training/

Environment variables:
    ANTHROPIC_API_KEY - Anthropic API key (for Claude)
    OPENAI_API_KEY - OpenAI API key (for GPT-4)
""",
    )

    # Input options
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--input",
        "-i",
        help="Path to segments.json file (will generate Q&A pairs)",
    )
    input_group.add_argument(
        "--qa-pairs",
        help="Path to existing Q&A pairs JSON (skip generation)",
    )

    # Output options
    parser.add_argument(
        "--output",
        "-o",
        default="data/training/",
        help="Output directory for training data (default: data/training/)",
    )

    # Generation options
    parser.add_argument(
        "--num-pairs",
        type=int,
        default=500,
        help="Number of Q&A pairs to generate (default: 500)",
    )
    parser.add_argument(
        "--questions-per-segment",
        type=int,
        default=3,
        help="Max questions per segment (default: 3)",
    )
    parser.add_argument(
        "--provider",
        choices=["anthropic", "openai"],
        default="anthropic",
        help="LLM API provider (default: anthropic)",
    )
    parser.add_argument(
        "--model",
        help="Model name (defaults: claude-sonnet-4-20250514 or gpt-4-turbo-preview)",
    )
    parser.add_argument(
        "--requests-per-minute",
        type=int,
        default=20,
        help="Rate limit for API requests (default: 20)",
    )

    # Dataset options
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.9,
        help="Train/validation split ratio (default: 0.9)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum tokens per training example (default: 2048)",
    )
    parser.add_argument(
        "--system-prompt-file",
        help="File containing custom system prompt",
    )

    # Persona options
    parser.add_argument(
        "--ceo-name",
        default="Ryouken Okuni",
        help="CEO name for persona (default: Ryouken Okuni)",
    )
    parser.add_argument(
        "--company-name",
        default="Akatsuki AI Technologies",
        help="Company name (default: Akatsuki AI Technologies)",
    )

    # Other options
    parser.add_argument(
        "--estimate-only",
        action="store_true",
        help="Only estimate cost, don't generate",
    )
    parser.add_argument(
        "--skip-generation",
        action="store_true",
        help="Skip Q&A generation, only build dataset from existing pairs",
    )
    parser.add_argument(
        "--yes",
        "-y",
        action="store_true",
        help="Skip confirmation prompts",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Verbose output",
    )
    return parser


def _has_api_key(provider: str) -> bool:
    """Return True if the API key for *provider* is set; otherwise print help."""
    env_var = _API_KEY_ENV_VARS[provider]
    if os.environ.get(env_var):
        return True
    console.print(f"[red]Error:[/red] {env_var} not found in environment")
    console.print("\nSet it with:")
    console.print(f" export {env_var}=your_key_here")
    return False


def _print_cost_estimate(estimate) -> None:
    """Render the generator's cost estimate mapping as a two-column table."""
    console.print("\n[yellow]Cost Estimate[/yellow]")
    table = Table(show_header=False, box=None)
    table.add_column(style="dim")
    table.add_column(style="white")
    table.add_row("Provider:", estimate["provider"])
    table.add_row("Model:", estimate["model"])
    table.add_row("Input tokens:", f"{estimate['estimated_input_tokens']:,}")
    table.add_row("Output tokens:", f"{estimate['estimated_output_tokens']:,}")
    table.add_row("Estimated cost:", f"${estimate['estimated_cost_usd']:.2f}")
    console.print(table)


def _print_dataset_stats(stats, verbose: bool) -> None:
    """Render the dataset statistics returned by the dataset builder."""
    console.print("\n[yellow]Dataset Statistics[/yellow]")
    table = Table(show_header=False, box=None)
    table.add_column(style="dim")
    table.add_column(style="white")
    table.add_row("Total examples:", str(stats.total_examples))
    table.add_row("Train examples:", str(stats.train_examples))
    table.add_row("Validation examples:", str(stats.validation_examples))
    table.add_row("Avg tokens/example:", f"{stats.avg_tokens_per_example:.1f}")
    table.add_row("Token range:", f"{stats.min_tokens} - {stats.max_tokens}")
    table.add_row("Total tokens:", f"{stats.total_tokens:,}")
    console.print(table)

    if verbose:
        console.print("\n [dim]Question type distribution:[/dim]")
        for q_type, count in stats.question_type_distribution.items():
            console.print(f" {q_type}: {count}")


def main():
    """Run the CLI: load or generate Q&A pairs, then build the dataset.

    Returns:
        Process exit code: 0 on success (including ``--estimate-only`` and a
        declined confirmation prompt), 1 on any reported error.
    """
    args = _build_parser().parse_args()

    console.print("\n[bold blue]AI Executive - Training Data Generator[/bold blue]")
    console.print("=" * 50)

    output_dir = Path(args.output)
    qa_pairs_path = output_dir / _QA_PAIRS_FILENAME

    # Validate the prompt file before any paid API call, so a typo in the
    # path cannot surface only after generation has already been billed.
    if args.system_prompt_file and not Path(args.system_prompt_file).is_file():
        console.print(
            f"[red]Error:[/red] System prompt file not found: {args.system_prompt_file}"
        )
        return 1

    qa_pairs = []

    # Load or generate Q&A pairs
    if args.qa_pairs or args.skip_generation:
        # --skip-generation together with --input reuses the pairs a previous
        # run saved in the output directory instead of calling the API again.
        pairs_source = args.qa_pairs or str(qa_pairs_path)
        if not Path(pairs_source).exists():
            console.print(f"[red]Error:[/red] Q&A pairs file not found: {pairs_source}")
            return 1

        console.print(f"\n[yellow]Loading Q&A pairs from:[/yellow] {pairs_source}")
        qa_pairs = QAGenerator.load_pairs(pairs_source)
        console.print(f" [green]✓[/green] Loaded {len(qa_pairs)} Q&A pairs")

    else:
        if not _has_api_key(args.provider):
            return 1

        # Load segments
        input_path = Path(args.input)
        if not input_path.exists():
            console.print(f"[red]Error:[/red] Input file not found: {input_path}")
            return 1

        console.print(f"\n[yellow]Loading segments from:[/yellow] {input_path}")
        try:
            segments = load_segments(input_path)
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
            console.print(f"[red]Error:[/red] Could not read segments file: {e!r}")
            return 1
        console.print(f" [green]✓[/green] Loaded {len(segments)} segments")

        # Initialize generator
        try:
            generator = QAGenerator(
                provider=args.provider,
                model=args.model,
                requests_per_minute=args.requests_per_minute,
                ceo_name=args.ceo_name,
                company_name=args.company_name,
            )
        except (ImportError, ValueError) as e:
            console.print(f"[red]Error initializing generator:[/red] {e}")
            return 1

        _print_cost_estimate(generator.estimate_cost(args.num_pairs))

        if args.estimate_only:
            return 0

        # Confirm generation
        if not args.yes and not Confirm.ask("\nProceed with generation?"):
            console.print("[dim]Cancelled.[/dim]")
            return 0

        # Created only now so --estimate-only and a declined prompt leave
        # the filesystem untouched.
        output_dir.mkdir(parents=True, exist_ok=True)

        # Generate Q&A pairs
        console.print(f"\n[yellow]Generating {args.num_pairs} Q&A pairs...[/yellow]")
        qa_pairs = generator.generate_from_segments(
            segments=segments,
            num_pairs=args.num_pairs,
            questions_per_segment=args.questions_per_segment,
            output_path=qa_pairs_path,
        )

        # Show actual cost
        actual = generator.get_actual_cost()
        console.print(f"\n [green]✓[/green] Generated {len(qa_pairs)} Q&A pairs")
        console.print(f" [green]✓[/green] Actual cost: ${actual['actual_cost_usd']:.2f}")
        console.print(f" [green]✓[/green] Saved to: {qa_pairs_path}")

    if not qa_pairs:
        console.print("[red]Error:[/red] No Q&A pairs available")
        return 1

    # Build training dataset
    console.print("\n[yellow]Building training dataset...[/yellow]")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load custom system prompt if provided
    system_prompt = None
    if args.system_prompt_file:
        with open(args.system_prompt_file, "r", encoding="utf-8") as f:
            system_prompt = f.read().strip()
        console.print(f" [dim]Using custom system prompt from: {args.system_prompt_file}[/dim]")

    builder = DatasetBuilder(
        system_prompt=system_prompt,
        ceo_name=args.ceo_name,
        company_name=args.company_name,
        max_tokens_per_example=args.max_tokens,
    )

    stats = builder.build_from_qa_pairs(
        qa_pairs=qa_pairs,
        output_dir=output_dir,
        train_ratio=args.train_ratio,
    )

    _print_dataset_stats(stats, args.verbose)

    # Summary
    console.print("\n" + "=" * 50)
    console.print("[bold green]Training data generation complete![/bold green]")
    console.print(f"\nOutput files in: {output_dir}")
    console.print(f" - train.jsonl ({stats.train_examples} examples)")
    console.print(f" - validation.jsonl ({stats.validation_examples} examples)")
    console.print(" - dataset_stats.json")
    if args.input:
        console.print(" - qa_pairs.json")

    console.print("\n[dim]Next step: Fine-tune the voice model[/dim]")
    console.print(f"[dim] python scripts/train_model.py --dataset {output_dir / 'train.jsonl'}[/dim]")

    return 0


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is not available
    # under `python -S` or in frozen/embedded interpreters.
    sys.exit(main())