File size: 7,309 Bytes
"""BIRD Mini-Dev loader + deterministic dev sample.
Source layout (after `scripts/download_data.py bird-mini-dev`):
data/bird_mini_dev/MINIDEV/
mini_dev_sqlite.json # 500 examples, schema documented below
mini_dev_mysql.json # 500 examples, MySQL dialect (same questions)
mini_dev_postgresql.json # 500 examples, PG dialect (same questions)
dev_databases/<db>/<db>.sqlite
Each item:
{
"question_id": int,
"db_id": str,
"question": str,
"evidence": str, # BIRD calls this "external knowledge", a hint
"SQL": str, # gold SQL for the dialect
"difficulty": "simple" | "moderate" | "challenging"
}
Per docs/03_eval_methodology.md §5: this loader is *evaluation-only*. The
few-shot pool MUST come from a separate train split — never the dev file.
A leakage-check helper (`is_in_dev_split`) is exposed for tests that guard
the few-shot index.
"""
from __future__ import annotations
import json
import random
import re
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
import sqlglot
from sqlglot import expressions as exp
# Closed label sets used in the annotations of this module.
Difficulty = Literal["simple", "moderate", "challenging"]
Dialect = Literal["sqlite", "mysql", "postgresql"]

# Default directory searched for the Mini-Dev JSON files (a relative path).
DEFAULT_BIRD_ROOT = Path("data", "bird_mini_dev", "MINIDEV")

# JSON file name for each supported dialect, looked up under the root directory.
_DIALECT_TO_FILE = {
    "sqlite": "mini_dev_sqlite.json",
    "mysql": "mini_dev_mysql.json",
    "postgresql": "mini_dev_postgresql.json",
}

# Tolerant table-name pattern for the regex fallback extractor. It matches
# `FROM <name>` and `JOIN <name>`, skips an optional unquoted `schema.`
# prefix, and captures a bare identifier (group 2) that may be wrapped in a
# matching pair of quote characters (group 1). Anything after the identifier,
# such as an alias, is not captured — gold tables are what we score, not
# aliases.
_TABLE_RE = re.compile(
    r"\b(?:FROM|JOIN)\s+(?:[A-Za-z_][\w]*\.)?([\"`']?)([A-Za-z_][\w]*)\1",
    flags=re.IGNORECASE,
)
@dataclass(frozen=True, slots=True)
class BirdExample:
"""One BIRD Mini-Dev question + gold SQL + difficulty + db_id."""
question_id: int
db_id: str # raw bird key, e.g. "debit_card_specializing"
question: str
evidence: str
sql: str
difficulty: Difficulty
dialect: Dialect = "sqlite"
@property
def registry_db_id(self) -> str:
"""Registry id used by `nl_sql.db.registry` β `bird_<db_id>`."""
return f"bird_{self.db_id}"
def load_bird_mini_dev(
    root: Path | str = DEFAULT_BIRD_ROOT,
    *,
    dialect: Dialect = "sqlite",
) -> list[BirdExample]:
    """Load every example from the Mini-Dev JSON file of one dialect.

    Args:
        root: Directory that contains the per-dialect Mini-Dev JSON files.
        dialect: Selects which file to read; also stamped onto each example.

    Returns:
        One ``BirdExample`` per item of the JSON document, in file order.

    Raises:
        FileNotFoundError: If the dialect's JSON file is not a file under
            ``root``.
        KeyError: If ``dialect`` has no entry in ``_DIALECT_TO_FILE``.
    """
    json_path = Path(root) / _DIALECT_TO_FILE[dialect]
    if json_path.is_file():
        with json_path.open("r", encoding="utf-8") as handle:
            records = json.load(handle)
        return [_to_example(record, dialect=dialect) for record in records]
    # Missing data is the common first-run failure, so point at the fix.
    raise FileNotFoundError(
        f"BIRD Mini-Dev file not found: {json_path}. "
        f"Run `python scripts/download_data.py bird-mini-dev` first."
    )
def dev_split(
    examples: Sequence[BirdExample],
    *,
    n: int,
    seed: int = 0,
) -> list[BirdExample]:
    """Pick a deterministic sample of at most `n` examples.

    The pool is shuffled once with `random.Random(seed)` and the first `n`
    entries of that order are kept. For a fixed pool and seed, the selection
    for a smaller `n` is therefore always contained in the selection for a
    larger `n`, so growing the eval slice (50 → 100 → 200) keeps every
    example of the smaller run instead of re-rolling.

    The returned list is ordered by `question_id`, not by shuffle position,
    so the smaller selection is a subset of the larger one rather than a
    literal list prefix.

    Args:
        examples: Pool to sample from; it is copied, never mutated.
        n: Number of examples wanted. `n <= 0` yields an empty list and
            `n >= len(examples)` yields the whole pool.
        seed: Seed for the shuffle.

    Returns:
        The selected examples sorted by `question_id`.
    """
    if n <= 0:
        return []
    picked = list(examples)
    if n < len(picked):
        # Shuffle the private copy in place, then keep only the leading n.
        random.Random(seed).shuffle(picked)
        del picked[n:]
    picked.sort(key=lambda example: example.question_id)
    return picked
def extract_gold_tables(sql: str) -> list[str]:
    """Walk the SQL AST and collect every base-table reference.

    Used by Schema Recall@k. Captures tables referenced anywhere in the
    query — FROM, JOIN, correlated subqueries inside WHERE / SELECT,
    IN-list subqueries, set operations, etc. CTE names defined via
    ``WITH ... AS (...)`` are excluded because they shadow base tables
    in scope and would inflate recall against the schema_chunks index.

    Falls back to the FROM/JOIN regex if sqlglot can't process the SQL —
    BIRD ships a small fraction of dialect-specific quirks that even
    the lenient parser may reject; better to under-count than crash.

    Args:
        sql: The SQL text to inspect; it is always parsed with the
            ``sqlite`` reader.

    Returns:
        Table names in first-seen order, de-duplicated case-insensitively
        and keeping the spelling of the first occurrence.
    """
    try:
        tree = sqlglot.parse_one(sql, read="sqlite")
    except sqlglot.errors.SqlglotError:
        # Catch the library's base error rather than ParseError alone:
        # tokenizer failures (e.g. an unterminated string literal) raise
        # TokenError, a sibling of ParseError, and must also reach the
        # fallback instead of crashing the caller.
        return _extract_via_regex(sql)
    if tree is None:
        return _extract_via_regex(sql)

    # Names introduced by WITH clauses. A table reference whose name matches
    # one of these points at a CTE, not at a base table.
    cte_names: set[str] = {
        cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE) if cte.alias_or_name
    }

    tables: list[str] = []
    seen: set[str] = set()
    # find_all covers the whole tree, so tables used inside CTE bodies and
    # subqueries are collected too; only references *to* a CTE are skipped.
    for node in tree.find_all(exp.Table):
        name = node.name
        if not name:
            continue
        key = name.lower()
        if key in cte_names or key in seen:
            continue
        seen.add(key)
        tables.append(name)

    if not tables:
        # The AST produced no usable table; give the regex a chance.
        return _extract_via_regex(sql)
    return tables
def _extract_via_regex(sql: str) -> list[str]:
    """Regex fallback: list the tables named after FROM / JOIN keywords.

    Names are de-duplicated case-insensitively; the first spelling seen is
    kept and first-seen order is preserved.
    """
    # Dicts preserve insertion order, so a lowercase-key -> original-spelling
    # map yields both the de-duplication and the ordering.
    first_spelling: dict[str, str] = {}
    for hit in _TABLE_RE.finditer(sql):
        name = hit.group(2)
        first_spelling.setdefault(name.lower(), name)
    return list(first_spelling.values())
def is_in_dev_split(question: str, dev_examples: Iterable[BirdExample]) -> bool:
    """Report whether `question` is a verbatim copy of any dev example's question.

    Helper for the leakage-check CI test (`test_no_dev_in_fewshot`). Both
    sides are stripped of surrounding whitespace and lowercased before an
    equality comparison. Paraphrases are deliberately NOT treated as leakage
    — only verbatim copies, which is the actual risk when curating a
    few-shot pool from public sources.

    Args:
        question: Candidate question text.
        dev_examples: Dev examples to compare against; consumed lazily and
            only up to the first match.

    Returns:
        True if a match exists, otherwise False.
    """
    target = question.strip().lower()
    for example in dev_examples:
        if example.question.strip().lower() == target:
            return True
    return False
def _to_example(item: dict[str, Any], *, dialect: Dialect) -> BirdExample:
    """Build a ``BirdExample`` from one raw Mini-Dev JSON object.

    ``question_id``, ``db_id``, ``question`` and ``SQL`` are required keys
    (a missing one raises ``KeyError``). ``evidence`` defaults to an empty
    string. A missing or unrecognised ``difficulty`` is recorded as
    ``"moderate"``.
    """
    raw_level = str(item.get("difficulty", "moderate"))
    level = raw_level if raw_level in ("simple", "moderate", "challenging") else "moderate"
    return BirdExample(
        question_id=int(item["question_id"]),
        db_id=str(item["db_id"]),
        question=str(item["question"]),
        evidence=str(item.get("evidence", "")),
        sql=str(item["SQL"]),
        difficulty=level,  # type: ignore[arg-type]
        dialect=dialect,
    )
|