File size: 3,760 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Variant generation orchestration for synthetic databases."""

from __future__ import annotations

import sqlite3
from dataclasses import dataclass
from pathlib import Path
from shutil import copy2

from .mutations import (
    MutationResult,
    detect_bridge_tables,
    duplicate_bridge_rows,
    get_table_schemas,
    inject_irrelevant_rows,
    remap_ids,
)
from .validate import validate_gold_sql


@dataclass
class VariantResult:
    """Result of generating a single synthetic database variant.

    Instances are built by ``generate_variant`` after the selected mutations
    have been attempted and the gold SQL has been checked against the
    variant database.
    """

    # Path (as a string) of the copied database the mutations were applied to.
    # ``generate_variant`` deletes this file when ``gold_sql_valid`` is False,
    # so the path is not guaranteed to exist on disk.
    variant_path: str
    # Path (as a string) of the source database the variant was copied from.
    original_path: str
    # One entry per mutation that was attempted, in execution order. A
    # mutation that hit ``sqlite3.IntegrityError`` is recorded with
    # ``success=False`` and is the last entry, since no further mutations run.
    mutations_applied: list[MutationResult]
    # First value returned by ``validate_gold_sql``; False when that call
    # raised ``sqlite3.OperationalError``.
    gold_sql_valid: bool
    # Second value returned by ``validate_gold_sql``; None when that call
    # raised ``sqlite3.OperationalError``.
    # NOTE(review): the validator may also return None in other cases —
    # confirm against ``validate_gold_sql``.
    gold_answer: str | None


def generate_variant(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    mutations: list[str] | None = None,
    variant_id: int = 0,
) -> VariantResult:
    """Generate a single variant database and validate gold SQL against it.

    The source database is copied to
    ``<output_dir>/<source stem>_variant_<variant_id>.sqlite``, the selected
    mutations are applied to the copy in the order given, and ``gold_sql`` is
    then validated against the copy. If validation fails, the copy is deleted
    before returning.

    Args:
        db_path: Path of the source SQLite database.
        gold_sql: SQL statement passed to ``validate_gold_sql``.
        output_dir: Directory for the variant file; created if missing.
        mutations: Names of the mutations to apply. ``None`` or an empty list
            selects every known mutation, in the order ``inject_irrelevant_rows``,
            ``remap_ids``, ``duplicate_bridge_rows``.
        variant_id: Suffix used in the variant file name.

    Returns:
        A ``VariantResult`` describing the variant. ``variant_path`` is
        reported even when the file has been removed because the gold SQL did
        not validate.

    Raises:
        FileNotFoundError: If ``db_path`` does not exist.
        ValueError: If ``mutations`` contains an unknown name. This is raised
            before the output directory or variant file is created, so a bad
            name leaves nothing behind on disk.
    """

    source_path = Path(db_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Database does not exist: {db_path}")

    # Every entry takes (variant path, schemas, bridge tables). Passing these
    # as arguments lets the table be built, and the requested names checked,
    # before any file is written.
    available_mutations = {
        "inject_irrelevant_rows": lambda path, schemas, _bridges: (
            inject_irrelevant_rows(path, schemas)
        ),
        "remap_ids": lambda path, schemas, _bridges: remap_ids(path, schemas),
        "duplicate_bridge_rows": lambda path, schemas, bridges: (
            duplicate_bridge_rows(path, schemas, bridges)
        ),
    }

    selected_mutations = mutations or list(available_mutations)
    unknown_mutations = [
        name for name in selected_mutations if name not in available_mutations
    ]
    if unknown_mutations:
        known = ", ".join(sorted(available_mutations))
        unknown = ", ".join(unknown_mutations)
        raise ValueError(f"Unknown mutation(s): {unknown}. Valid mutations: {known}")

    # Inputs are valid: only now touch the file system.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    variant_filename = f"{source_path.stem}_variant_{variant_id}.sqlite"
    variant_path = output_path / variant_filename
    copy2(source_path, variant_path)

    # Schemas and bridge tables are read once from the fresh copy and shared
    # by all mutations.
    variant_db = str(variant_path)
    schemas = get_table_schemas(variant_db)
    bridge_tables = detect_bridge_tables(schemas)

    mutation_results: list[MutationResult] = []
    for mutation_name in selected_mutations:
        mutation_fn = available_mutations[mutation_name]
        try:
            mutation_results.append(mutation_fn(variant_db, schemas, bridge_tables))
        except sqlite3.IntegrityError:
            # Record the failure and stop: later mutations are not attempted
            # once one of them has violated a constraint.
            mutation_results.append(
                MutationResult(
                    mutation_name=mutation_name,
                    tables_affected=[],
                    rows_added=0,
                    success=False,
                )
            )
            break

    try:
        gold_sql_valid, gold_answer = validate_gold_sql(variant_db, gold_sql)
    except sqlite3.OperationalError:
        # A gold query that cannot run against the variant counts as invalid.
        gold_sql_valid, gold_answer = False, None

    # Variants on which the gold SQL does not validate are not kept on disk.
    if not gold_sql_valid and variant_path.exists():
        variant_path.unlink()

    return VariantResult(
        variant_path=variant_db,
        original_path=str(source_path),
        mutations_applied=mutation_results,
        gold_sql_valid=gold_sql_valid,
        gold_answer=gold_answer,
    )


def generate_variants_for_question(
    db_path: str,
    gold_sql: str,
    output_dir: str,
    n_variants: int = 2,
) -> list[VariantResult]:
    """Build up to ``n_variants`` variants and keep those whose gold SQL validates.

    Variants are generated one after another with ids ``0 .. n_variants - 1``
    using the default mutation selection of ``generate_variant``. A
    non-positive ``n_variants`` yields an empty list without generating
    anything.
    """

    if n_variants <= 0:
        return []

    # Lazy generator: each variant is produced, then filtered, before the
    # next one is started.
    attempts = (
        generate_variant(
            db_path=db_path,
            gold_sql=gold_sql,
            output_dir=output_dir,
            variant_id=index,
        )
        for index in range(n_variants)
    )
    return [attempt for attempt in attempts if attempt.gold_sql_valid]