"""Variant generation orchestration for synthetic databases.""" from __future__ import annotations import sqlite3 from dataclasses import dataclass from pathlib import Path from shutil import copy2 from .mutations import ( MutationResult, detect_bridge_tables, duplicate_bridge_rows, get_table_schemas, inject_irrelevant_rows, remap_ids, ) from .validate import validate_gold_sql @dataclass class VariantResult: """Result of generating a single synthetic database variant.""" variant_path: str original_path: str mutations_applied: list[MutationResult] gold_sql_valid: bool gold_answer: str | None def generate_variant( db_path: str, gold_sql: str, output_dir: str, mutations: list[str] | None = None, variant_id: int = 0, ) -> VariantResult: """Generate a single variant database and validate gold SQL against it.""" source_path = Path(db_path) if not source_path.exists(): raise FileNotFoundError(f"Database does not exist: {db_path}") output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) variant_filename = f"{source_path.stem}_variant_{variant_id}.sqlite" variant_path = output_path / variant_filename copy2(source_path, variant_path) schemas = get_table_schemas(str(variant_path)) bridge_tables = detect_bridge_tables(schemas) available_mutations = { "inject_irrelevant_rows": lambda: inject_irrelevant_rows( str(variant_path), schemas ), "remap_ids": lambda: remap_ids(str(variant_path), schemas), "duplicate_bridge_rows": lambda: duplicate_bridge_rows( str(variant_path), schemas, bridge_tables ), } selected_mutations = mutations or list(available_mutations) unknown_mutations = [ name for name in selected_mutations if name not in available_mutations ] if unknown_mutations: known = ", ".join(sorted(available_mutations)) unknown = ", ".join(unknown_mutations) raise ValueError(f"Unknown mutation(s): {unknown}. Valid mutations: {known}") mutation_results: list[MutationResult] = [] for mutation_name in selected_mutations: mutation_fn = available_mutations[mutation_name] try: mutation_results.append(mutation_fn()) except sqlite3.IntegrityError: mutation_results.append( MutationResult( mutation_name=mutation_name, tables_affected=[], rows_added=0, success=False, ) ) break try: gold_sql_valid, gold_answer = validate_gold_sql(str(variant_path), gold_sql) except sqlite3.OperationalError: gold_sql_valid, gold_answer = False, None if not gold_sql_valid and variant_path.exists(): variant_path.unlink() return VariantResult( variant_path=str(variant_path), original_path=str(source_path), mutations_applied=mutation_results, gold_sql_valid=gold_sql_valid, gold_answer=gold_answer, ) def generate_variants_for_question( db_path: str, gold_sql: str, output_dir: str, n_variants: int = 2, ) -> list[VariantResult]: """Generate multiple variants and return only those that validate.""" if n_variants <= 0: return [] variants: list[VariantResult] = [] for variant_id in range(n_variants): result = generate_variant( db_path=db_path, gold_sql=gold_sql, output_dir=output_dir, variant_id=variant_id, ) if result.gold_sql_valid: variants.append(result) return variants