File size: 3,760 Bytes
5dd1bb4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | """Variant generation orchestration for synthetic databases."""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from shutil import copy2
from .mutations import (
MutationResult,
detect_bridge_tables,
duplicate_bridge_rows,
get_table_schemas,
inject_irrelevant_rows,
remap_ids,
)
from .validate import validate_gold_sql
@dataclass(init=True, repr=True, eq=True)
class VariantResult:
    """Outcome of building one synthetic copy of a database.

    Instances are returned by ``generate_variant``. When ``gold_sql_valid`` is
    false, ``generate_variant`` deletes the copied file before returning, so
    ``variant_path`` may point at a file that no longer exists.
    """

    # Location of the generated variant database file (as a string path).
    variant_path: str
    # Location of the source database the variant was copied from.
    original_path: str
    # Results of the mutations attempted on the variant, in attempt order.
    mutations_applied: list[MutationResult]
    # Whether the gold SQL validated successfully against the variant.
    gold_sql_valid: bool
    # Answer reported by gold-SQL validation, or None (e.g. when validation
    # raised sqlite3.OperationalError in generate_variant).
    gold_answer: str | None
def generate_variant(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    mutations: list[str] | None = None,
    variant_id: int = 0,
) -> VariantResult:
    """Copy a database, mutate the copy, and validate gold SQL against it.

    The copy is written to ``output_dir`` as
    ``<source stem>_variant_<variant_id>.sqlite``. The selected mutations are
    applied to it in order. The first mutation that raises
    ``sqlite3.IntegrityError`` is recorded as a failed ``MutationResult``, and
    no further mutations are attempted. ``gold_sql`` is then checked against
    the copy with ``validate_gold_sql``.

    Args:
        db_path: Path of the source database file.
        gold_sql: SQL query whose validity is checked on the variant.
        output_dir: Directory for the variant; created if missing.
        mutations: Names of mutations to apply, in order. ``None`` (the
            default) applies every known mutation; an empty list applies none.
        variant_id: Number embedded in the variant's filename.

    Returns:
        A ``VariantResult``. When ``gold_sql_valid`` is false, the variant file
        has already been deleted, although ``variant_path`` is still reported.

    Raises:
        FileNotFoundError: ``db_path`` does not exist.
        ValueError: An unknown mutation name was requested. This is checked
            before anything is created or copied on disk.

    Any other exception raised while mutating or validating the copy deletes
    the variant file and then propagates.
    """
    source_path = Path(db_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Database does not exist: {db_path}")

    # Every mutation is invoked as fn(path, schemas, bridge_tables). This lets
    # the registry be built, and requested names checked, before any file is
    # copied. Insertion order is the default application order.
    mutation_registry = {
        "inject_irrelevant_rows": lambda path, schemas, bridges: (
            inject_irrelevant_rows(path, schemas)
        ),
        "remap_ids": lambda path, schemas, bridges: remap_ids(path, schemas),
        "duplicate_bridge_rows": lambda path, schemas, bridges: (
            duplicate_bridge_rows(path, schemas, bridges)
        ),
    }

    # Test with ``is None`` rather than truthiness, so an explicit empty list
    # means "no mutations" instead of silently falling back to all of them.
    selected_mutations = (
        list(mutation_registry) if mutations is None else list(mutations)
    )
    unknown_mutations = [
        name for name in selected_mutations if name not in mutation_registry
    ]
    if unknown_mutations:
        known = ", ".join(sorted(mutation_registry))
        unknown = ", ".join(unknown_mutations)
        raise ValueError(f"Unknown mutation(s): {unknown}. Valid mutations: {known}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    variant_path = output_path / f"{source_path.stem}_variant_{variant_id}.sqlite"
    variant_str = str(variant_path)
    copy2(source_path, variant_path)

    keep_variant = False
    try:
        schemas = get_table_schemas(variant_str)
        bridge_tables = detect_bridge_tables(schemas)

        mutation_results: list[MutationResult] = []
        for mutation_name in selected_mutations:
            mutation_fn = mutation_registry[mutation_name]
            try:
                mutation_results.append(
                    mutation_fn(variant_str, schemas, bridge_tables)
                )
            except sqlite3.IntegrityError:
                # Record the failed mutation and stop applying further ones.
                mutation_results.append(
                    MutationResult(
                        mutation_name=mutation_name,
                        tables_affected=[],
                        rows_added=0,
                        success=False,
                    )
                )
                break

        try:
            gold_sql_valid, gold_answer = validate_gold_sql(variant_str, gold_sql)
        except sqlite3.OperationalError:
            # Treat an OperationalError during validation as invalid gold SQL.
            gold_sql_valid, gold_answer = False, None
        keep_variant = bool(gold_sql_valid)
    finally:
        # Delete the copy when validation failed or when an unexpected
        # exception escaped, so no orphaned variant files are left behind.
        if not keep_variant and variant_path.exists():
            variant_path.unlink()

    return VariantResult(
        variant_path=variant_str,
        original_path=str(source_path),
        mutations_applied=mutation_results,
        gold_sql_valid=gold_sql_valid,
        gold_answer=gold_answer,
    )
def generate_variants_for_question(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    n_variants: int = 2,
) -> list[VariantResult]:
    """Build several variants of one database and keep only the valid ones.

    ``generate_variant`` is called once for each id in ``range(n_variants)``,
    in order, without an explicit mutation list. Only results whose
    ``gold_sql_valid`` is truthy are returned. A non-positive ``n_variants``
    returns an empty list before any variant is generated.
    """
    if n_variants <= 0:
        return []

    # Lazily produce each attempt; the comprehension below consumes them in
    # id order and keeps the ones whose gold SQL validated.
    attempts = (
        generate_variant(
            db_path=db_path,
            gold_sql=gold_sql,
            output_dir=output_dir,
            variant_id=idx,
        )
        for idx in range(n_variants)
    )
    return [attempt for attempt in attempts if attempt.gold_sql_valid]
|