File size: 7,309 Bytes
942050b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""BIRD Mini-Dev loader + deterministic dev sample.

Source layout (after `scripts/download_data.py bird-mini-dev`):

    data/bird_mini_dev/MINIDEV/
      mini_dev_sqlite.json       # 500 examples, schema documented below
      mini_dev_mysql.json        # 500 examples, MySQL dialect (same questions)
      mini_dev_postgresql.json   # 500 examples, PG dialect (same questions)
      dev_databases/<db>/<db>.sqlite

Each item:
    {
      "question_id": int,
      "db_id": str,
      "question": str,
      "evidence": str,           # BIRD calls this "external knowledge", a hint
      "SQL": str,                # gold SQL for the dialect
      "difficulty": "simple" | "moderate" | "challenging"
    }

Per docs/03_eval_methodology.md §5: this loader is *evaluation-only*. The
few-shot pool MUST come from a separate train split — never the dev file.
A leakage-check helper (`is_in_dev_split`) is exposed for tests that guard
the few-shot index.
"""

from __future__ import annotations

import json
import random
import re
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import sqlglot
from sqlglot import expressions as exp

# BIRD difficulty labels carried on each example.
Difficulty = Literal["simple", "moderate", "challenging"]
# SQL dialects the Mini-Dev release ships; `load_bird_mini_dev` maps each one
# to its JSON file through `_DIALECT_TO_FILE` below.
Dialect = Literal["sqlite", "mysql", "postgresql"]

# Relative path, so it resolves against the process's working directory.
DEFAULT_BIRD_ROOT = Path("data") / "bird_mini_dev" / "MINIDEV"

# Dialect key -> per-dialect JSON filename under the Mini-Dev root.
_DIALECT_TO_FILE = {
    "sqlite": "mini_dev_sqlite.json",
    "mysql": "mini_dev_mysql.json",
    "postgresql": "mini_dev_postgresql.json",
}

# A tolerant table-name extractor used by `_extract_via_regex`, the fallback
# path of `extract_gold_tables`. Matches `FROM <name>` and `JOIN <name>`
# (case-insensitive, with an optional unquoted schema prefix `db.`). Group 1
# captures an optional opening quote (", ` or '); the `\1` back-reference
# requires the same quote to close the name. Group 2 is the bare identifier,
# which ends at the first non-word character (whitespace, comma, paren, ...).
# Aliases are dropped by design — gold tables are what we score, not aliases.
# NOTE(review): quoted names containing spaces (e.g. `Player Attributes`) do
# not match at all, and `FROM` inside expressions such as
# `EXTRACT(YEAR FROM col)` yields a false-positive "table".
_TABLE_RE = re.compile(
    r"\b(?:FROM|JOIN)\s+(?:[A-Za-z_][\w]*\.)?([\"`']?)([A-Za-z_][\w]*)\1",
    re.IGNORECASE,
)


@dataclass(frozen=True, slots=True)
class BirdExample:
    """One BIRD Mini-Dev question + gold SQL + difficulty + db_id."""

    question_id: int
    db_id: str  # raw bird key, e.g. "debit_card_specializing"
    question: str
    evidence: str
    sql: str
    difficulty: Difficulty
    dialect: Dialect = "sqlite"

    @property
    def registry_db_id(self) -> str:
        """Registry id used by `nl_sql.db.registry` β€” `bird_<db_id>`."""
        return f"bird_{self.db_id}"


def load_bird_mini_dev(
    root: Path | str = DEFAULT_BIRD_ROOT,
    *,
    dialect: Dialect = "sqlite",
) -> list[BirdExample]:
    """Load every example from the Mini-Dev JSON file for one dialect.

    Args:
        root: directory that holds the ``mini_dev_*.json`` files.
        dialect: which dialect's file to read.

    Returns:
        One ``BirdExample`` per JSON record, in file order.

    Raises:
        KeyError: ``dialect`` has no entry in ``_DIALECT_TO_FILE``.
        FileNotFoundError: the dialect's JSON file does not exist under ``root``.
    """
    json_path = Path(root).joinpath(_DIALECT_TO_FILE[dialect])
    if not json_path.is_file():
        raise FileNotFoundError(
            f"BIRD Mini-Dev file not found: {json_path}. "
            f"Run `python scripts/download_data.py bird-mini-dev` first."
        )
    records = json.loads(json_path.read_text(encoding="utf-8"))
    examples: list[BirdExample] = []
    for record in records:
        examples.append(_to_example(record, dialect=dialect))
    return examples


def dev_split(
    examples: Sequence[BirdExample],
    *,
    n: int,
    seed: int = 0,
) -> list[BirdExample]:
    """Deterministic sample of `n` examples with a stable-subset property.

    Implementation: shuffle a private copy of the pool once with
    ``random.Random(seed)`` and keep the first `n`. For the same `examples`
    (in the same order) and the same seed, the shuffle order does not depend
    on `n`. So the examples chosen by ``dev_split(..., n=k1)`` are a subset
    of those chosen by ``dev_split(..., n=k2)`` whenever ``k1 <= k2``.
    Growing the eval slice (50 → 100 → 200) therefore re-uses every cached
    prompt from the smaller run instead of re-rolling.

    The result is sorted by ``question_id`` for reader stability (eval
    reports want stable IDs). Because of that re-sort, a smaller slice is a
    *subset* of a larger one, not necessarily a positional prefix of it.

    Args:
        examples: the full evaluation pool; it is not modified.
        n: number of examples to draw.
        seed: seed for the shuffle.

    Returns:
        ``[]`` when ``n <= 0``. Every example, sorted, when
        ``n >= len(examples)``. Otherwise `n` examples sorted by ``question_id``.
    """
    if n <= 0:
        return []
    pool = list(examples)  # private copy: the caller's sequence is never shuffled
    if n < len(pool):
        random.Random(seed).shuffle(pool)
        pool = pool[:n]
    return sorted(pool, key=lambda e: e.question_id)


def extract_gold_tables(sql: str, *, dialect: str = "sqlite") -> list[str]:
    """Walk the SQL AST and collect every base-table reference.

    Used by Schema Recall@k. Captures tables referenced anywhere in the
    query — FROM, JOIN, correlated subqueries inside WHERE / SELECT,
    IN-list subqueries, set operations, etc. CTE names defined via
    ``WITH ... AS (...)`` are excluded because they shadow base tables
    in scope and would inflate recall against the schema_chunks index.
    Tables referenced *inside* a CTE body are real base tables and are kept.

    Falls back to the FROM/JOIN regex when sqlglot cannot tokenize or
    parse the SQL, or when the parsed tree contains no table at all. BIRD
    ships a small fraction of dialect-specific quirks that even the lenient
    parser may reject; better to under-count than crash.

    Args:
        sql: gold SQL string.
        dialect: dialect the SQL is written in. Accepts the loader's
            ``Dialect`` keys ("sqlite", "mysql", "postgresql"); "postgresql"
            is translated to sqlglot's "postgres" reader name, and any other
            value is passed to sqlglot unchanged. Defaults to "sqlite".

    Returns:
        Table names in first-seen order, de-duplicated case-insensitively,
        keeping the casing of the first occurrence.
    """
    read = "postgres" if dialect == "postgresql" else dialect
    try:
        tree = sqlglot.parse_one(sql, read=read)
    except sqlglot.errors.SqlglotError:
        # SqlglotError covers ParseError *and* TokenError (e.g. an
        # unterminated string literal). The tokenizer raises TokenError,
        # which is not a ParseError subclass.
        return _extract_via_regex(sql)
    if tree is None:
        return _extract_via_regex(sql)

    # Names defined in a WITH block; references to them are not base tables.
    cte_names: set[str] = {
        cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE) if cte.alias_or_name
    }

    tables: list[str] = []
    seen: set[str] = set()
    for node in tree.find_all(exp.Table):
        name = node.name
        if not name:
            continue
        key = name.lower()
        if key in cte_names or key in seen:
            continue
        seen.add(key)
        tables.append(name)
    # A tree with no Table nodes (e.g. a construct sqlglot models
    # differently) gets a second chance via the regex.
    return tables or _extract_via_regex(sql)


def _extract_via_regex(sql: str) -> list[str]:
    """Regex-only table extraction, the fallback when sqlglot gives up.

    Returns the names captured by `_TABLE_RE` in order of first appearance,
    de-duplicated case-insensitively and keeping the first-seen casing.
    """
    # dict preserves insertion order; setdefault keeps the first spelling.
    by_key: dict[str, str] = {}
    for hit in _TABLE_RE.finditer(sql):
        candidate = hit.group(2)
        by_key.setdefault(candidate.lower(), candidate)
    return list(by_key.values())


def is_in_dev_split(question: str, dev_examples: Iterable[BirdExample]) -> bool:
    """Helper for the leakage-check CI test (`test_no_dev_in_fewshot`).

    Report whether `question` is a verbatim copy of any dev example's
    question. The only normalisation applied to both sides is stripping
    leading/trailing whitespace and lower-casing. Paraphrases are NOT
    considered leakage, only verbatim copies (which is the actual risk when
    curating a few-shot pool from public sources).

    Stops at the first match, so a generator passed as `dev_examples` may be
    only partially consumed.
    """

    def _normalise(text: str) -> str:
        return text.strip().lower()

    target = _normalise(question)
    for example in dev_examples:
        if _normalise(example.question) == target:
            return True
    return False


def _to_example(item: dict[str, Any], *, dialect: Dialect) -> BirdExample:
    """Convert one raw Mini-Dev JSON record into a `BirdExample`.

    Required keys are ``question_id``, ``db_id``, ``question`` and ``SQL``;
    a missing one raises ``KeyError``.

    ``evidence`` is optional. A missing or JSON-``null`` value becomes ``""``
    rather than the literal text ``"None"``.

    ``difficulty`` is matched after stripping whitespace and lower-casing.
    Anything outside the three BIRD labels, including a missing or ``null``
    value, falls back to ``"moderate"``.
    """
    difficulty = str(item.get("difficulty") or "").strip().lower()
    if difficulty not in ("simple", "moderate", "challenging"):
        difficulty = "moderate"
    evidence = item.get("evidence")
    return BirdExample(
        question_id=int(item["question_id"]),
        db_id=str(item["db_id"]),
        question=str(item["question"]),
        # str(None) would inject the word "None" into prompts as a hint.
        evidence="" if evidence is None else str(evidence),
        sql=str(item["SQL"]),
        difficulty=difficulty,  # type: ignore[arg-type]
        dialect=dialect,
    )