#!/usr/bin/env python3
"""
Generate Training Data CLI

Generate Q&A pairs from processed segments using Claude/GPT-4 API,
then build the final training dataset in ChatML format.

Usage:
    python scripts/generate_training_data.py --input data/processed/segments.json --output data/training/

Environment variables:
    ANTHROPIC_API_KEY - Required for Claude API
    OPENAI_API_KEY - Required for OpenAI API
"""

import argparse
import json
import os
import sys
from pathlib import Path

# Add src to path for imports (repo root = two levels up from this script),
# so `from src....` below resolves when run as `python scripts/...`.
sys.path.insert(0, str(Path(__file__).parent.parent))

from rich.console import Console
from rich.prompt import Confirm
from rich.table import Table

from src.data_processing.qa_generator import QAGenerator, QAPair
from src.data_processing.dataset_builder import DatasetBuilder

# Shared rich console for all CLI output in this script.
console = Console()
def load_segments(path: Path) -> list:
    """Load segments from a JSON file.

    Args:
        path: Path to a JSON file containing a list of segment dicts.
            Each dict must have a "content" key; "segment_index" and
            "source_post_title" are optional and defaulted.

    Returns:
        A list of lightweight ``Segment`` objects exposing ``content``,
        ``segment_index``, and ``source_post_title`` attributes.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Convert to simple objects for the generator.
    from dataclasses import dataclass

    # BUG FIX: the @dataclass decorator was missing, so Segment had only
    # class-level annotations and no generated __init__ — every
    # Segment(content=...) call below raised TypeError.
    @dataclass
    class Segment:
        content: str
        segment_index: int
        source_post_title: str

    return [
        Segment(
            content=s["content"],
            # Fall back to the list position when the source omits an index.
            segment_index=s.get("segment_index", i),
            source_post_title=s.get("source_post_title", "Unknown"),
        )
        for i, s in enumerate(data)
    ]
def main():
    """CLI entry point: generate Q&A pairs and/or build the training dataset.

    Returns a process exit code (0 on success, 1 on error) rather than
    calling sys.exit directly, so the __main__ guard owns process exit.
    """
    parser = argparse.ArgumentParser(
        description="Generate Q&A training data using LLM APIs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate 500 Q&A pairs using Claude
  python scripts/generate_training_data.py \\
      --input data/processed/segments.json \\
      --output data/training/ \\
      --num-pairs 500

  # Use OpenAI instead
  python scripts/generate_training_data.py \\
      --input data/processed/segments.json \\
      --output data/training/ \\
      --provider openai

  # Just estimate cost without generating
  python scripts/generate_training_data.py \\
      --input data/processed/segments.json \\
      --estimate-only

  # Load existing Q&A pairs and just build dataset
  python scripts/generate_training_data.py \\
      --qa-pairs data/processed/qa_pairs.json \\
      --output data/training/

Environment variables:
  ANTHROPIC_API_KEY - Anthropic API key (for Claude)
  OPENAI_API_KEY - OpenAI API key (for GPT-4)
""",
    )

    # Input options: exactly one of --input (generate) or --qa-pairs (reuse).
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        "--input", "-i",
        help="Path to segments.json file (will generate Q&A pairs)",
    )
    input_group.add_argument(
        "--qa-pairs",
        help="Path to existing Q&A pairs JSON (skip generation)",
    )

    # Output options
    parser.add_argument(
        "--output", "-o",
        default="data/training/",
        help="Output directory for training data (default: data/training/)",
    )

    # Generation options
    parser.add_argument(
        "--num-pairs",
        type=int,
        default=500,
        help="Number of Q&A pairs to generate (default: 500)",
    )
    parser.add_argument(
        "--questions-per-segment",
        type=int,
        default=3,
        help="Max questions per segment (default: 3)",
    )
    parser.add_argument(
        "--provider",
        choices=["anthropic", "openai"],
        default="anthropic",
        help="LLM API provider (default: anthropic)",
    )
    parser.add_argument(
        "--model",
        help="Model name (defaults: claude-sonnet-4-20250514 or gpt-4-turbo-preview)",
    )
    parser.add_argument(
        "--requests-per-minute",
        type=int,
        default=20,
        help="Rate limit for API requests (default: 20)",
    )

    # Dataset options
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.9,
        help="Train/validation split ratio (default: 0.9)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=2048,
        help="Maximum tokens per training example (default: 2048)",
    )
    parser.add_argument(
        "--system-prompt-file",
        help="File containing custom system prompt",
    )

    # Persona options
    parser.add_argument(
        "--ceo-name",
        default="Ryouken Okuni",
        help="CEO name for persona (default: Ryouken Okuni)",
    )
    parser.add_argument(
        "--company-name",
        default="Akatsuki AI Technologies",
        help="Company name (default: Akatsuki AI Technologies)",
    )

    # Other options
    parser.add_argument(
        "--estimate-only",
        action="store_true",
        help="Only estimate cost, don't generate",
    )
    # NOTE(review): --skip-generation is parsed but never read below;
    # --qa-pairs already covers the "skip generation" path. Confirm whether
    # this flag is vestigial or intended for future use.
    parser.add_argument(
        "--skip-generation",
        action="store_true",
        help="Skip Q&A generation, only build dataset from existing pairs",
    )
    parser.add_argument(
        "--yes", "-y",
        action="store_true",
        help="Skip confirmation prompts",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Verbose output",
    )

    args = parser.parse_args()

    console.print("\n[bold blue]AI Executive - Training Data Generator[/bold blue]")
    console.print("=" * 50)

    # Create output directory (idempotent).
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    qa_pairs = []

    # Load or generate Q&A pairs
    if args.qa_pairs:
        # Load existing Q&A pairs and skip the (costly) generation step.
        console.print(f"\n[yellow]Loading Q&A pairs from:[/yellow] {args.qa_pairs}")
        qa_pairs = QAGenerator.load_pairs(args.qa_pairs)
        console.print(f" [green]✓[/green] Loaded {len(qa_pairs)} Q&A pairs")
    elif args.input:
        # Check API key up front so we fail before any file work.
        if args.provider == "anthropic":
            api_key = os.environ.get("ANTHROPIC_API_KEY")
            if not api_key:
                console.print("[red]Error:[/red] ANTHROPIC_API_KEY not found in environment")
                console.print("\nSet it with:")
                console.print(" export ANTHROPIC_API_KEY=your_key_here")
                return 1
        else:
            api_key = os.environ.get("OPENAI_API_KEY")
            if not api_key:
                console.print("[red]Error:[/red] OPENAI_API_KEY not found in environment")
                console.print("\nSet it with:")
                console.print(" export OPENAI_API_KEY=your_key_here")
                return 1

        # Load segments
        input_path = Path(args.input)
        if not input_path.exists():
            console.print(f"[red]Error:[/red] Input file not found: {input_path}")
            return 1
        console.print(f"\n[yellow]Loading segments from:[/yellow] {input_path}")
        segments = load_segments(input_path)
        console.print(f" [green]✓[/green] Loaded {len(segments)} segments")

        # Initialize generator. ImportError/ValueError cover a missing SDK
        # or an invalid provider/model configuration.
        try:
            generator = QAGenerator(
                provider=args.provider,
                model=args.model,
                requests_per_minute=args.requests_per_minute,
                ceo_name=args.ceo_name,
                company_name=args.company_name,
            )
        except (ImportError, ValueError) as e:
            console.print(f"[red]Error initializing generator:[/red] {e}")
            return 1

        # Show cost estimate before spending any API budget.
        estimate = generator.estimate_cost(args.num_pairs)
        console.print("\n[yellow]Cost Estimate[/yellow]")
        table = Table(show_header=False, box=None)
        table.add_column(style="dim")
        table.add_column(style="white")
        table.add_row("Provider:", estimate["provider"])
        table.add_row("Model:", estimate["model"])
        table.add_row("Input tokens:", f"{estimate['estimated_input_tokens']:,}")
        table.add_row("Output tokens:", f"{estimate['estimated_output_tokens']:,}")
        table.add_row("Estimated cost:", f"${estimate['estimated_cost_usd']:.2f}")
        console.print(table)

        if args.estimate_only:
            return 0

        # Confirm generation unless --yes was given.
        if not args.yes:
            if not Confirm.ask("\nProceed with generation?"):
                console.print("[dim]Cancelled.[/dim]")
                return 0

        # Generate Q&A pairs; the generator persists them incrementally
        # to qa_pairs_path as it goes.
        console.print(f"\n[yellow]Generating {args.num_pairs} Q&A pairs...[/yellow]")
        qa_pairs_path = output_dir / "qa_pairs.json"
        qa_pairs = generator.generate_from_segments(
            segments=segments,
            num_pairs=args.num_pairs,
            questions_per_segment=args.questions_per_segment,
            output_path=qa_pairs_path,
        )

        # Show actual cost (may differ from the pre-run estimate).
        actual = generator.get_actual_cost()
        console.print(f"\n [green]✓[/green] Generated {len(qa_pairs)} Q&A pairs")
        console.print(f" [green]✓[/green] Actual cost: ${actual['actual_cost_usd']:.2f}")
        console.print(f" [green]✓[/green] Saved to: {qa_pairs_path}")

    if not qa_pairs:
        console.print("[red]Error:[/red] No Q&A pairs available")
        return 1

    # Build training dataset
    console.print(f"\n[yellow]Building training dataset...[/yellow]")

    # Load custom system prompt if provided; None lets the builder use its
    # own default prompt.
    system_prompt = None
    if args.system_prompt_file:
        with open(args.system_prompt_file, "r", encoding="utf-8") as f:
            system_prompt = f.read().strip()
        console.print(f" [dim]Using custom system prompt from: {args.system_prompt_file}[/dim]")

    builder = DatasetBuilder(
        system_prompt=system_prompt,
        ceo_name=args.ceo_name,
        company_name=args.company_name,
        max_tokens_per_example=args.max_tokens,
    )
    stats = builder.build_from_qa_pairs(
        qa_pairs=qa_pairs,
        output_dir=output_dir,
        train_ratio=args.train_ratio,
    )

    # Show statistics
    console.print("\n[yellow]Dataset Statistics[/yellow]")
    table = Table(show_header=False, box=None)
    table.add_column(style="dim")
    table.add_column(style="white")
    table.add_row("Total examples:", str(stats.total_examples))
    table.add_row("Train examples:", str(stats.train_examples))
    table.add_row("Validation examples:", str(stats.validation_examples))
    table.add_row("Avg tokens/example:", f"{stats.avg_tokens_per_example:.1f}")
    table.add_row("Token range:", f"{stats.min_tokens} - {stats.max_tokens}")
    table.add_row("Total tokens:", f"{stats.total_tokens:,}")
    console.print(table)

    if args.verbose:
        console.print("\n [dim]Question type distribution:[/dim]")
        for q_type, count in stats.question_type_distribution.items():
            console.print(f" {q_type}: {count}")

    # Summary
    console.print("\n" + "=" * 50)
    console.print("[bold green]Training data generation complete![/bold green]")
    console.print(f"\nOutput files in: {output_dir}")
    console.print(f" - train.jsonl ({stats.train_examples} examples)")
    console.print(f" - validation.jsonl ({stats.validation_examples} examples)")
    console.print(" - dataset_stats.json")
    if args.input:
        console.print(" - qa_pairs.json")
    console.print("\n[dim]Next step: Fine-tune the voice model[/dim]")
    console.print(f"[dim] python scripts/train_model.py --dataset {output_dir / 'train.jsonl'}[/dim]")

    return 0
if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() builtin: exit() is installed
    # by the site module and is not guaranteed to exist (e.g. under
    # `python -S` or in frozen executables); sys.exit is always available.
    sys.exit(main())