| """Variant generation orchestration for synthetic databases.""" |
|
|
| from __future__ import annotations |
|
|
| import sqlite3 |
| from dataclasses import dataclass |
| from pathlib import Path |
| from shutil import copy2 |
|
|
| from .mutations import ( |
| MutationResult, |
| detect_bridge_tables, |
| duplicate_bridge_rows, |
| get_table_schemas, |
| inject_irrelevant_rows, |
| remap_ids, |
| ) |
| from .validate import validate_gold_sql |
|
|
|
|
@dataclass
class VariantResult:
    """Outcome of building one mutated copy of a source database.

    Attributes:
        variant_path: Location of the variant database file, as a string.
        original_path: Location of the database the variant was copied from.
        mutations_applied: One ``MutationResult`` per mutation that was
            attempted, in the order the mutations ran.
        gold_sql_valid: Whether the gold SQL validated against the variant.
        gold_answer: Answer reported by gold SQL validation, or ``None``
            when no answer is available.
    """

    variant_path: str
    original_path: str
    mutations_applied: list[MutationResult]
    gold_sql_valid: bool
    gold_answer: str | None
|
|
|
|
def generate_variant(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    mutations: list[str] | None = None,
    variant_id: int = 0,
) -> VariantResult:
    """Generate a single variant database and validate gold SQL against it.

    The source database is copied to
    ``<output_dir>/<source stem>_variant_<variant_id>.sqlite``, the selected
    mutations are applied to the copy in the order given, and ``gold_sql`` is
    then checked against the mutated copy with ``validate_gold_sql``.

    All arguments are validated before anything is written to disk, so an
    invalid call leaves no directory or stray variant file behind.

    Args:
        db_path: Path of the source database file.
        gold_sql: SQL statement handed to ``validate_gold_sql``.
        output_dir: Directory that receives the variant; created (including
            parents) when it does not exist.
        mutations: Names of the mutations to apply, in order. ``None`` or an
            empty list selects every available mutation.
        variant_id: Number embedded in the variant's file name.

    Returns:
        A ``VariantResult`` describing the variant. When the gold SQL does not
        validate, the variant file is deleted and ``variant_path`` refers to
        a file that no longer exists.

    Raises:
        FileNotFoundError: If ``db_path`` does not exist.
        ValueError: If ``mutations`` contains an unknown mutation name.
    """
    source_path = Path(db_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Database does not exist: {db_path}")

    # Every entry takes (path, schemas, bridge_tables) so the table can be
    # built, and the requested names checked, before the copy exists.
    available_mutations = {
        "inject_irrelevant_rows": lambda path, schemas, _bridges: (
            inject_irrelevant_rows(path, schemas)
        ),
        "remap_ids": lambda path, schemas, _bridges: remap_ids(path, schemas),
        "duplicate_bridge_rows": lambda path, schemas, bridges: (
            duplicate_bridge_rows(path, schemas, bridges)
        ),
    }

    # Falsy input (None or an empty list) means "apply everything".
    selected_mutations = mutations or list(available_mutations)
    unknown_mutations = [
        name for name in selected_mutations if name not in available_mutations
    ]
    if unknown_mutations:
        known = ", ".join(sorted(available_mutations))
        unknown = ", ".join(unknown_mutations)
        raise ValueError(f"Unknown mutation(s): {unknown}. Valid mutations: {known}")

    # Only touch the filesystem once the request is known to be well-formed.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    variant_filename = f"{source_path.stem}_variant_{variant_id}.sqlite"
    variant_path = output_path / variant_filename
    copy2(source_path, variant_path)

    variant_db = str(variant_path)
    schemas = get_table_schemas(variant_db)
    bridge_tables = detect_bridge_tables(schemas)

    mutation_results: list[MutationResult] = []
    for mutation_name in selected_mutations:
        mutation_fn = available_mutations[mutation_name]
        try:
            mutation_results.append(mutation_fn(variant_db, schemas, bridge_tables))
        except sqlite3.IntegrityError:
            # Record the failed mutation and skip the remaining ones; the
            # gold SQL is still validated against the database as it stands.
            mutation_results.append(
                MutationResult(
                    mutation_name=mutation_name,
                    tables_affected=[],
                    rows_added=0,
                    success=False,
                )
            )
            break

    try:
        gold_sql_valid, gold_answer = validate_gold_sql(variant_db, gold_sql)
    except sqlite3.OperationalError:
        gold_sql_valid, gold_answer = False, None

    if not gold_sql_valid:
        # A variant that cannot answer the gold SQL is discarded. Deleting
        # directly avoids the race between an exists() check and unlink().
        try:
            variant_path.unlink()
        except FileNotFoundError:
            pass

    return VariantResult(
        variant_path=variant_db,
        original_path=str(source_path),
        mutations_applied=mutation_results,
        gold_sql_valid=gold_sql_valid,
        gold_answer=gold_answer,
    )
|
|
|
|
def generate_variants_for_question(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    n_variants: int = 2,
) -> list[VariantResult]:
    """Build up to ``n_variants`` variants and keep the ones that validate.

    Variants are generated one at a time with ids ``0 .. n_variants - 1``,
    each using the default mutation selection of ``generate_variant``.
    Results whose ``gold_sql_valid`` flag is false are dropped, so the
    returned list may be shorter than ``n_variants``. A non-positive
    ``n_variants`` yields an empty list without generating anything.
    """
    if n_variants <= 0:
        return []

    # Lazy generator: each variant is produced only as the filter consumes it.
    attempts = (
        generate_variant(
            db_path=db_path,
            gold_sql=gold_sql,
            output_dir=output_dir,
            variant_id=index,
        )
        for index in range(n_variants)
    )
    return [attempt for attempt in attempts if attempt.gold_sql_valid]
|
|