sql_env / server /synthetic /generate.py
hjerpe's picture
Upload folder using huggingface_hub
5dd1bb4 verified
Raw
History Blame Contribute Delete
3.76 kB
"""Variant generation orchestration for synthetic databases."""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from shutil import copy2
from .mutations import (
MutationResult,
detect_bridge_tables,
duplicate_bridge_rows,
get_table_schemas,
inject_irrelevant_rows,
remap_ids,
)
from .validate import validate_gold_sql
@dataclass
class VariantResult:
    """Outcome of building one synthetic variant of a source database.

    Attributes:
        variant_path: Filesystem path of the variant database, as a string.
        original_path: Filesystem path of the database the variant was
            copied from, as a string.
        mutations_applied: One ``MutationResult`` per mutation that was
            attempted on the variant.
        gold_sql_valid: Whether the gold SQL validated against the variant.
        gold_answer: Answer reported by gold-SQL validation, or ``None``
            when no answer is available.
    """

    variant_path: str
    original_path: str
    mutations_applied: list[MutationResult]
    gold_sql_valid: bool
    gold_answer: str | None
def generate_variant(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    mutations: list[str] | None = None,
    variant_id: int = 0,
) -> VariantResult:
    """Generate a single variant database and validate gold SQL against it.

    The source database is copied to
    ``<output_dir>/<source stem>_variant_<variant_id>.sqlite`` and the
    selected mutations are applied to the copy in the order given. The gold
    SQL is then validated against the mutated copy; if it does not validate,
    the copy is deleted again.

    Args:
        db_path: Path of the source SQLite database.
        gold_sql: Reference query that must still validate on the variant.
        output_dir: Directory for the variant file; created if missing.
        mutations: Names of the mutations to apply. ``None`` or an empty
            list applies every available mutation in the default order.
        variant_id: Number embedded in the variant file name.

    Returns:
        A ``VariantResult`` describing the variant. When ``gold_sql_valid``
        is ``False`` the file at ``variant_path`` has already been removed.

    Raises:
        FileNotFoundError: If ``db_path`` does not exist.
        ValueError: If ``mutations`` contains an unknown name. This is
            checked before anything is written to ``output_dir``.
    """
    source_path = Path(db_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Database does not exist: {db_path}")

    # Dispatch table keyed by mutation name. Each entry takes the variant
    # path, the table schemas and the bridge tables, so the table can be
    # built (and the requested names validated) before any file is created.
    available_mutations = {
        "inject_irrelevant_rows": lambda path, schemas, bridges: (
            inject_irrelevant_rows(path, schemas)
        ),
        "remap_ids": lambda path, schemas, bridges: remap_ids(path, schemas),
        "duplicate_bridge_rows": lambda path, schemas, bridges: (
            duplicate_bridge_rows(path, schemas, bridges)
        ),
    }

    selected_mutations = mutations or list(available_mutations)
    unknown_mutations = [
        name for name in selected_mutations if name not in available_mutations
    ]
    if unknown_mutations:
        # Fail before copying so a bad request leaves no orphaned variant.
        known = ", ".join(sorted(available_mutations))
        unknown = ", ".join(unknown_mutations)
        raise ValueError(f"Unknown mutation(s): {unknown}. Valid mutations: {known}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    variant_filename = f"{source_path.stem}_variant_{variant_id}.sqlite"
    variant_path = output_path / variant_filename
    copy2(source_path, variant_path)

    variant_file = str(variant_path)
    schemas = get_table_schemas(variant_file)
    bridge_tables = detect_bridge_tables(schemas)

    mutation_results: list[MutationResult] = []
    for mutation_name in selected_mutations:
        mutation_fn = available_mutations[mutation_name]
        try:
            mutation_results.append(mutation_fn(variant_file, schemas, bridge_tables))
        except sqlite3.IntegrityError:
            # Record the failure and stop: later mutations are not attempted
            # once one of them has violated a constraint.
            mutation_results.append(
                MutationResult(
                    mutation_name=mutation_name,
                    tables_affected=[],
                    rows_added=0,
                    success=False,
                )
            )
            break

    try:
        gold_sql_valid, gold_answer = validate_gold_sql(variant_file, gold_sql)
    except sqlite3.OperationalError:
        # A query that no longer runs on the variant counts as invalid.
        gold_sql_valid, gold_answer = False, None

    # Variants on which the gold SQL does not validate are not kept on disk.
    if not gold_sql_valid and variant_path.exists():
        variant_path.unlink()

    return VariantResult(
        variant_path=variant_file,
        original_path=str(source_path),
        mutations_applied=mutation_results,
        gold_sql_valid=gold_sql_valid,
        gold_answer=gold_answer,
    )
def generate_variants_for_question(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    n_variants: int = 2,
) -> list[VariantResult]:
    """Build up to ``n_variants`` variants and keep those whose gold SQL validates.

    Variants are generated one after another with ids ``0 .. n_variants - 1``
    using the default mutation selection of ``generate_variant``. A
    non-positive ``n_variants`` yields an empty list.
    """
    # range() over a count clamped at zero is empty for non-positive requests,
    # so no variant is generated in that case.
    attempts = (
        generate_variant(
            db_path=db_path,
            gold_sql=gold_sql,
            output_dir=output_dir,
            variant_id=index,
        )
        for index in range(max(n_variants, 0))
    )
    return [attempt for attempt in attempts if attempt.gold_sql_valid]