#!/usr/bin/env python3
"""
Generate Training Data CLI
Generate Q&A pairs from processed segments using Claude/GPT-4 API,
then build the final training dataset in ChatML format.
Usage:
python scripts/generate_training_data.py --input data/processed/segments.json --output data/training/
Environment variables:
ANTHROPIC_API_KEY - Required for Claude API
OPENAI_API_KEY - Required for OpenAI API
"""
import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from rich.console import Console
from rich.prompt import Confirm
from rich.table import Table
from src.data_processing.qa_generator import QAGenerator
from src.data_processing.dataset_builder import DatasetBuilder
console = Console()
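# Expected segments.json shape: a JSON list of objects. Only "content" is
# required; missing fields fall back to the defaults used in load_segments(),
# e.g.:
# [
#     {"content": "...", "segment_index": 0, "source_post_title": "Post title"},
#     ...
# ]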
@dataclass
class Segment:
    """Lightweight view of one processed segment, as expected by the generator."""
    content: str
    segment_index: int
    source_post_title: str


def load_segments(path: Path) -> list:
    """Load segments from a JSON file into Segment objects."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [
        Segment(
            content=s["content"],
            segment_index=s.get("segment_index", i),
            source_post_title=s.get("source_post_title", "Unknown"),
        )
        for i, s in enumerate(data)
    ]
def main():
parser = argparse.ArgumentParser(
description="Generate Q&A training data using LLM APIs",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Generate 500 Q&A pairs using Claude
python scripts/generate_training_data.py \\
--input data/processed/segments.json \\
--output data/training/ \\
--num-pairs 500
# Use OpenAI instead
python scripts/generate_training_data.py \\
--input data/processed/segments.json \\
--output data/training/ \\
--provider openai
# Just estimate cost without generating
python scripts/generate_training_data.py \\
--input data/processed/segments.json \\
--estimate-only
# Load existing Q&A pairs and just build dataset
python scripts/generate_training_data.py \\
--qa-pairs data/processed/qa_pairs.json \\
--output data/training/
Environment variables:
ANTHROPIC_API_KEY - Anthropic API key (for Claude)
OPENAI_API_KEY - OpenAI API key (for GPT-4)
""",
)
# Input options
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument(
"--input", "-i",
help="Path to segments.json file (will generate Q&A pairs)",
)
input_group.add_argument(
"--qa-pairs",
help="Path to existing Q&A pairs JSON (skip generation)",
)
# Output options
parser.add_argument(
"--output", "-o",
default="data/training/",
help="Output directory for training data (default: data/training/)",
)
# Generation options
parser.add_argument(
"--num-pairs",
type=int,
default=500,
help="Number of Q&A pairs to generate (default: 500)",
)
parser.add_argument(
"--questions-per-segment",
type=int,
default=3,
help="Max questions per segment (default: 3)",
)
parser.add_argument(
"--provider",
choices=["anthropic", "openai"],
default="anthropic",
help="LLM API provider (default: anthropic)",
)
parser.add_argument(
"--model",
help="Model name (defaults: claude-sonnet-4-20250514 or gpt-4-turbo-preview)",
)
parser.add_argument(
"--requests-per-minute",
type=int,
default=20,
help="Rate limit for API requests (default: 20)",
)
# Dataset options
parser.add_argument(
"--train-ratio",
type=float,
default=0.9,
help="Train/validation split ratio (default: 0.9)",
)
parser.add_argument(
"--max-tokens",
type=int,
default=2048,
help="Maximum tokens per training example (default: 2048)",
)
parser.add_argument(
"--system-prompt-file",
help="File containing custom system prompt",
)
# Persona options
parser.add_argument(
"--ceo-name",
default="Ryouken Okuni",
help="CEO name for persona (default: Ryouken Okuni)",
)
parser.add_argument(
"--company-name",
default="Akatsuki AI Technologies",
help="Company name (default: Akatsuki AI Technologies)",
)
# Other options
parser.add_argument(
"--estimate-only",
action="store_true",
help="Only estimate cost, don't generate",
)
parser.add_argument(
"--yes", "-y",
action="store_true",
help="Skip confirmation prompts",
)
parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Verbose output",
)
args = parser.parse_args()
console.print("\n[bold blue]AI Executive - Training Data Generator[/bold blue]")
console.print("=" * 50)
# Create output directory
output_dir = Path(args.output)
output_dir.mkdir(parents=True, exist_ok=True)
qa_pairs = []
# Load or generate Q&A pairs
if args.qa_pairs:
# Load existing Q&A pairs
console.print(f"\n[yellow]Loading Q&A pairs from:[/yellow] {args.qa_pairs}")
qa_pairs = QAGenerator.load_pairs(args.qa_pairs)
console.print(f" [green]✓[/green] Loaded {len(qa_pairs)} Q&A pairs")
elif args.input:
# Check API key
if args.provider == "anthropic":
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
console.print("[red]Error:[/red] ANTHROPIC_API_KEY not found in environment")
console.print("\nSet it with:")
console.print(" export ANTHROPIC_API_KEY=your_key_here")
return 1
else:
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
console.print("[red]Error:[/red] OPENAI_API_KEY not found in environment")
console.print("\nSet it with:")
console.print(" export OPENAI_API_KEY=your_key_here")
return 1
# Load segments
input_path = Path(args.input)
if not input_path.exists():
console.print(f"[red]Error:[/red] Input file not found: {input_path}")
return 1
console.print(f"\n[yellow]Loading segments from:[/yellow] {input_path}")
segments = load_segments(input_path)
console.print(f" [green]✓[/green] Loaded {len(segments)} segments")
# Initialize generator
try:
generator = QAGenerator(
provider=args.provider,
model=args.model,
requests_per_minute=args.requests_per_minute,
ceo_name=args.ceo_name,
company_name=args.company_name,
)
except (ImportError, ValueError) as e:
console.print(f"[red]Error initializing generator:[/red] {e}")
return 1
# Show cost estimate
estimate = generator.estimate_cost(args.num_pairs)
console.print("\n[yellow]Cost Estimate[/yellow]")
table = Table(show_header=False, box=None)
table.add_column(style="dim")
table.add_column(style="white")
table.add_row("Provider:", estimate["provider"])
table.add_row("Model:", estimate["model"])
table.add_row("Input tokens:", f"{estimate['estimated_input_tokens']:,}")
table.add_row("Output tokens:", f"{estimate['estimated_output_tokens']:,}")
table.add_row("Estimated cost:", f"${estimate['estimated_cost_usd']:.2f}")
console.print(table)
if args.estimate_only:
return 0
# Confirm generation
if not args.yes:
if not Confirm.ask("\nProceed with generation?"):
console.print("[dim]Cancelled.[/dim]")
return 0
# Generate Q&A pairs
console.print(f"\n[yellow]Generating {args.num_pairs} Q&A pairs...[/yellow]")
qa_pairs_path = output_dir / "qa_pairs.json"
qa_pairs = generator.generate_from_segments(
segments=segments,
num_pairs=args.num_pairs,
questions_per_segment=args.questions_per_segment,
output_path=qa_pairs_path,
)
# Show actual cost
actual = generator.get_actual_cost()
console.print(f"\n [green]✓[/green] Generated {len(qa_pairs)} Q&A pairs")
console.print(f" [green]✓[/green] Actual cost: ${actual['actual_cost_usd']:.2f}")
console.print(f" [green]✓[/green] Saved to: {qa_pairs_path}")
if not qa_pairs:
console.print("[red]Error:[/red] No Q&A pairs available")
return 1
# Build training dataset
console.print(f"\n[yellow]Building training dataset...[/yellow]")
# Load custom system prompt if provided
system_prompt = None
if args.system_prompt_file:
with open(args.system_prompt_file, "r", encoding="utf-8") as f:
system_prompt = f.read().strip()
console.print(f" [dim]Using custom system prompt from: {args.system_prompt_file}[/dim]")
builder = DatasetBuilder(
system_prompt=system_prompt,
ceo_name=args.ceo_name,
company_name=args.company_name,
max_tokens_per_example=args.max_tokens,
)
stats = builder.build_from_qa_pairs(
qa_pairs=qa_pairs,
output_dir=output_dir,
train_ratio=args.train_ratio,
)
# Show statistics
console.print("\n[yellow]Dataset Statistics[/yellow]")
table = Table(show_header=False, box=None)
table.add_column(style="dim")
table.add_column(style="white")
table.add_row("Total examples:", str(stats.total_examples))
table.add_row("Train examples:", str(stats.train_examples))
table.add_row("Validation examples:", str(stats.validation_examples))
table.add_row("Avg tokens/example:", f"{stats.avg_tokens_per_example:.1f}")
table.add_row("Token range:", f"{stats.min_tokens} - {stats.max_tokens}")
table.add_row("Total tokens:", f"{stats.total_tokens:,}")
console.print(table)
if args.verbose:
console.print("\n [dim]Question type distribution:[/dim]")
for q_type, count in stats.question_type_distribution.items():
console.print(f" {q_type}: {count}")
# Summary
console.print("\n" + "=" * 50)
console.print("[bold green]Training data generation complete![/bold green]")
console.print(f"\nOutput files in: {output_dir}")
console.print(f" - train.jsonl ({stats.train_examples} examples)")
console.print(f" - validation.jsonl ({stats.validation_examples} examples)")
console.print(" - dataset_stats.json")
if args.input:
console.print(" - qa_pairs.json")
console.print("\n[dim]Next step: Fine-tune the voice model[/dim]")
console.print(f"[dim] python scripts/train_model.py --dataset {output_dir / 'train.jsonl'}[/dim]")
return 0
if __name__ == "__main__":
    sys.exit(main())